In [2]:
import argparse
import requests
from batchgenerators.utilities.file_and_folder_operations import *
from nnunetv2.model_sharing.model_download import download_and_install_from_url
from nnunetv2.model_sharing.model_import import install_model_from_zip_file
from nnunetv2.paths import nnUNet_results
from tqdm import tqdm
import os
import zipfile
from time import time
from nnunetv2.inference.predict import predict_from_files
from nnunetv2.utilities.file_path_utilities import *

ModuleNotFoundError: No module named 'nnunetv2.inference.predict'

In [None]:

# License warning
def print_license_warning():
    """Print a boxed reminder that pretrained weights are bound by their dataset's license.

    The notice is written to stdout, framed by blank lines and '#' rules so it
    stands out in console/notebook output. Returns None.
    """
    rule = '######################################################'
    notice = ("Using the pretrained model weights is subject to the license of the dataset they were trained on. Some "
              "allow commercial use, others don't. It is your responsibility to make sure you use them appropriately! Use "
              "nnUNet_print_pretrained_model_info(task_name) to see a summary of the dataset and where to find its license!")
    banner_lines = [
        '',
        rule,
        '!!!!!!!!!!!!!!!!!!!!!!!!WARNING!!!!!!!!!!!!!!!!!!!!!!!',
        rule,
        notice,
        rule,
        '',
    ]
    for text in banner_lines:
        print(text)

In [None]:

# Function to download and install a pretrained model from a URL
def download_by_url(url):
    """Download a zipped nnU-Net model from ``url`` and install it.

    The archive is streamed into a time-stamped hidden file in the user's home
    directory, passed to ``install_model_from_zip_file``, and then always deleted,
    whether installation succeeded or failed.

    Parameters
    ----------
    url : str
        Direct download link to a model ``.zip`` archive.

    Raises
    ------
    Exception
        Any error from the download or the installation propagates unchanged;
        the temporary archive is still removed.
    """
    print(f"Downloading model from URL: {url}")
    home = os.path.expanduser('~')
    # Timestamp-derived suffix keeps repeated downloads from clobbering each other.
    random_number = int(time() * 1e7)
    zip_path = os.path.join(home, f'.nnunetdownload_{random_number}')

    # No except clause: re-raising the caught exception added nothing but an extra
    # traceback frame; `finally` alone guarantees cleanup while errors propagate.
    try:
        download_file(url=url, local_filename=zip_path, chunk_size=8192 * 16)
        print("Download finished. Extracting model...")
        install_model_from_zip_file(zip_path)
        print("Model installed successfully!")
    finally:
        if os.path.isfile(zip_path):
            os.remove(zip_path)

In [None]:

# Function to download files from a URL
def download_file(url: str, local_filename: str, chunk_size: int = 8192 * 16):
    """Stream ``url`` into ``local_filename``, showing a tqdm progress bar.

    Parameters
    ----------
    url : str
        Resource to download.
    local_filename : str
        Destination path; overwritten if it exists.
    chunk_size : int
        Bytes requested per ``iter_content`` read.

    Returns
    -------
    str
        ``local_filename``, for convenience.

    Raises
    ------
    requests.HTTPError
        Via ``raise_for_status`` on a non-2xx response.
    """
    with requests.get(url, stream=True, timeout=100) as r:
        r.raise_for_status()
        # Content-Length is absent for chunked or compressed responses; int(None)
        # would raise TypeError, so use an open-ended progress bar instead.
        content_length = r.headers.get("Content-Length")
        total = int(content_length) if content_length else None
        with tqdm.wrapattr(open(local_filename, 'wb'), "write", total=total) as f:
            for chunk in r.iter_content(chunk_size=chunk_size):
                if chunk:  # skip empty keep-alive chunks
                    f.write(chunk)
    return local_filename

In [None]:

# Function to perform inference on medical images
def perform_inference(pretrained_model, input_image, output_folder,
                      configuration="3d_fullres", folds=(0,),
                      trainer_class_name="nnUNetTrainer", plans_identifier="nnUNetPlans",
                      save_probabilities=True, checkpoint_name="checkpoint_final.pth"):
    """Segment ``input_image`` with an installed nnU-Net v2 model and write results to ``output_folder``.

    Parameters
    ----------
    pretrained_model : str or int
        Dataset name or id of the installed model; it is resolved to a folder under
        ``nnUNet_results`` via ``get_output_folder``. NOTE(review): nnU-Net v2 expects
        ``DatasetXXX_Name`` or an integer id; v1-style ``TaskXXX_Name`` strings are
        presumably rejected there — confirm.
    input_image : str
        A single image file or a folder of images. A single file is wrapped as one
        case with one channel (``[[input_image]]``). NOTE(review): nnU-Net appears to
        derive the case id by stripping ``_0000`` plus the file ending, so files
        should follow that naming — verify.
    output_folder : str
        Destination for the predicted segmentations.
    configuration, folds, trainer_class_name, plans_identifier : str / tuple of int
        Select the trained model folder and folds; the defaults equal the values
        this function previously hard-coded.
    save_probabilities : bool
        Also export class probabilities (the intent of the former ``save_npz=True``).
    checkpoint_name : str
        Checkpoint file loaded from each fold directory.
    """
    # Local imports: the module-level `nnunetv2.inference.predict` import does not
    # exist in nnU-Net v2; the predictor class lives in predict_from_raw_data.
    import torch
    from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
    from nnunetv2.utilities.file_path_utilities import get_output_folder

    print(f"Running inference on: {input_image}")

    # nnU-Net takes either a source folder (str) or a list of cases, each case being
    # a list of per-channel files. A bare [path] would be iterated character by character.
    if os.path.isdir(input_image):
        inputs = input_image
    else:
        inputs = [[input_image]]

    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    predictor = nnUNetPredictor(device=device)
    model_folder = get_output_folder(pretrained_model, trainer_class_name, plans_identifier, configuration)
    predictor.initialize_from_trained_model_folder(model_folder, use_folds=tuple(folds),
                                                   checkpoint_name=checkpoint_name)
    # Postprocessing is a separate step in nnU-Net v2 and is not applied here.
    predictor.predict_from_files(inputs, output_folder,
                                 save_probabilities=save_probabilities,
                                 overwrite=True,
                                 num_processes_preprocessing=1,
                                 num_processes_segmentation_export=1)
    print(f"Inference complete. Results saved in {output_folder}")

In [None]:

# Function to evaluate predictions
def evaluate_predictions(ground_truth, predictions_folder):
    """Placeholder for comparing predictions against ground truth; computes no metrics yet.

    It announces what would be evaluated and then states explicitly that no metrics
    were produced. It deliberately does not print fixed, made-up scores, which would
    be mistaken for real results.

    Parameters
    ----------
    ground_truth : str
        Path to the reference segmentation(s).
    predictions_folder : str
        Folder containing the predicted segmentations.

    Returns
    -------
    None
    """
    print(f"Evaluating predictions in folder {predictions_folder} against ground truth {ground_truth}")
    # TODO: implement real metrics (e.g. per-label Dice). nnU-Net v2 presumably ships
    # folder-level evaluation helpers under nnunetv2.evaluation — verify before use.
    print("Evaluation not implemented: no metrics were computed.")
    return None

In [None]:

# Main function to handle the overall flow
def main(argv=None):
    """Command-line entry point: optionally download a model, run inference, optionally evaluate.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse instead of ``sys.argv[1:]``. Pass them explicitly inside
        Jupyter, where ``sys.argv`` holds the kernel's own flags and argparse would
        otherwise exit with an error.

    Raises
    ------
    SystemExit
        From argparse when required arguments are missing or ``--help`` is given.
    """
    parser = argparse.ArgumentParser(description="Medical Image Disease Prediction using nnU-Net")
    parser.add_argument('--download_model_url', type=str, help='URL of the pretrained model to download')
    parser.add_argument('--input_image', type=str, required=True, help='Path to input medical image')
    parser.add_argument('--output_folder', type=str, required=True, help='Folder to save predictions')
    parser.add_argument('--ground_truth', type=str, help='Path to ground truth for evaluation (optional)')
    # NOTE(review): nnU-Net v2 identifies models as DatasetXXX_Name or an integer id;
    # the default keeps the previously hard-coded value — confirm it matches your install.
    parser.add_argument('--model', type=str, default='Task005_Prostate',
                        help='Dataset name/id of the installed model to run')
    args = parser.parse_args(argv)

    # Step 1: Download the model
    if args.download_model_url:
        print_license_warning()
        download_by_url(args.download_model_url)

    # Step 2: Perform inference
    perform_inference(args.model, args.input_image, args.output_folder)

    # Step 3: Optional evaluation
    if args.ground_truth:
        evaluate_predictions(args.ground_truth, args.output_folder)

if __name__ == '__main__':
    import sys
    # In a Jupyter kernel __name__ is also '__main__', but sys.argv belongs to the
    # kernel, so auto-running would make argparse exit. Require an explicit call there.
    if 'ipykernel' in sys.modules:
        print("Running inside Jupyter: call main([...]) with explicit arguments, e.g. "
              "main(['--input_image', 'case_0000.nii.gz', '--output_folder', 'predictions'])")
    else:
        main()
