In [1]:
import os
import torch

from glob import glob
from monai.data import Dataset, DataLoader, decollate_batch
from monai.inferers import sliding_window_inference
from monai.transforms import (
    Spacingd,
    AsDiscreted,
    Compose,
    EnsureChannelFirstd,
    Invertd,
    LoadImaged,
    Orientationd,
    ScaleIntensityd,
    KeepLargestConnectedComponentd,
    SaveImaged,
    NormalizeIntensityd
)
from monai.networks.layers.factories import Act, Norm
from tqdm import tqdm

In [2]:
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [3]:
from monai.networks.nets import UNet
from prostate158.network.unetr import UNETR
from monai.networks.nets import SwinUNETR
from prostate158.network.nnFormer.nnFormer_seg import nnFormer
from prostate158.network.UXNet_3D.network_backbone import UXNET
from prostate158.network.mixformer.mixing_unetr_qkv import MixingUNETR 

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



  from .autonotebook import tqdm as notebook_tqdm


In [4]:
def get_model(network):

    # mixunetr系列
    if network == "mixunetr":
        return MixingUNETR(
                img_size=(96, 96, 96),
                depths=[2,2,2,2],
                in_channels=1,
                out_channels=3,
                feature_size=48,
            )

    
    # unetr系列
    elif network == "unetr":
        return UNETR(
                in_channels=1,
                out_channels=3,
                img_size=(96,96,96),
                feature_size=48,
                hidden_size=768,
                mlp_dim=3072,
                num_heads=12,
                pos_embed="perceptron",
                norm_name="instance",
                conv_block=True,
                res_block=True,
                dropout_rate=0.0,
            )
    elif network == '3dunet':
        return UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=3,
            channels=[16, 32, 64, 128, 256, 512],
            strides=[2, 2, 2, 2, 2],
            num_res_units=0,
            act='PRELU',
            norm='BATCH',
            dropout=0.15,
                )
    elif network == '3dresunet':
        return UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=3,
            channels=[16, 32, 64, 128, 256, 512],
            strides=[2, 2, 2, 2, 2],
            num_res_units=4,
            act='PRELU',
            norm='BATCH',
            dropout=0.15,
                )
    # nnformer系列
    elif network == "nnformer":
        return nnFormer(
            input_channels=1,
            num_classes=3,
            )
    
    elif network == "swinunetr":
        return SwinUNETR(
                in_channels=1,
                out_channels=3,
                depths=[2,2,2,2],
                img_size=(96,96,96),
                feature_size=48,
            )        
    elif network == "uxnet":
        return UXNET(
                in_chans=1,
                out_chans=3,
                depths=[2, 2, 2, 2],
                feat_size=[48, 96, 192, 384],
                drop_path_rate=0,
                layer_scale_init_value=1e-6,
                spatial_dims=3,
            )

In [5]:
def main(tempdir, net, output_dir= ''):

    images = sorted(glob(os.path.join(tempdir, "*.nii.gz")))
    files = [{"img": img} for img in images]

    # define pre transforms
    pre_transforms = Compose(
        [
            LoadImaged(keys="img"),
            EnsureChannelFirstd(keys="img"),
            Orientationd(keys="img", axcodes="RAS"),
            Spacingd(keys="img", pixdim=[0.5, 0.5, 0.5]),
            ScaleIntensityd(keys="img", minv=0, maxv=1),
            NormalizeIntensityd(keys="img"),
        ]
    )
    # define dataset and dataloader
    dataset = Dataset(data=files, transform=pre_transforms)
    dataloader = DataLoader(dataset, batch_size=1, num_workers=4)
    # define post transforms
    post_transforms = Compose(
        [
            AsDiscreted(
                keys="pred", 
                argmax=True, 
                # to_onehot=3, 
                num_classes=3
            ),
            KeepLargestConnectedComponentd(
                keys="pred", 
                applied_labels=list(range(1, 3))
            ),
            Invertd(
                keys="pred",  # invert the `pred` data field, also support multiple fields
                transform=pre_transforms,
                orig_keys="img",  # get the previously applied pre_transforms information on the `img` data field,
                # then invert `pred` based on this information. we can use same info
                # for multiple fields, also support different orig_keys for different fields
                nearest_interp=True,  # don't change the interpolation mode to "nearest" when inverting transforms
                # to ensure a smooth output, then execute `AsDiscreted` transform
                to_tensor=True,  # convert to PyTorch Tensor after inverting
            ),
            SaveImaged(keys="pred", output_dir=output_dir, output_postfix="seg", resample=False),
        ]
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # net = AW_UNet(
    #     spatial_dims=3,
    #     in_channels=1,
    #     out_channels=3,
    #     channels=[16, 32, 64, 128, 256, 512],
    #     strides=[2, 2, 2, 2, 2],
    #     num_res_units=4,
    #     act=Act.PRELU,
    #     norm=Norm.BATCH,
    #     dropout=0.15,
    #     use_checkpoint=False
    # ).to(device)
    # net.load_state_dict(torch.load("/home/data/wan/prostate158/models/network_pp_aw_dou_seed0_1_key_metric=0.8198.pt"))


    net.eval()
    with torch.no_grad():
        for d in tqdm(dataloader):
            images = d["img"].to(device)
            # define sliding window size and batch size for windows inference
            d["pred"] = sliding_window_inference(inputs=images, roi_size=(96, 96, 96), sw_batch_size=4, overlap=0.5, predictor=net)
            # decollate the batch data into a list of dictionaries, then execute postprocessing transforms
            d = [post_transforms(i) for i in decollate_batch(d)]

In [8]:
network= "mixunetr"
model = get_model(network).to(device)

model.load_state_dict(torch.load("/home/data1/skyous/prostate158_log/checkpoint/mixunetr_qkv_0.8253.pt"))

tempdir = "/home/skyous/git/prostate158/vision/001_img"
output_dir = os.path.join('/home/skyous/git/prostate158/vision/pred/', network)
main(tempdir, model, output_dir)

  0%|          | 0/19 [00:00<?, ?it/s]

2024-03-29 09:05:35,756 INFO image_writer.py:197 - writing: /home/skyous/git/prostate158/vision/pred/mixunetr/image_001/image_001_seg.nii.gz


  5%|▌         | 1/19 [00:08<02:40,  8.94s/it]

2024-03-29 09:05:38,647 INFO image_writer.py:197 - writing: /home/skyous/git/prostate158/vision/pred/mixunetr/image_002/image_002_seg.nii.gz


 11%|█         | 2/19 [00:11<01:32,  5.45s/it]

2024-03-29 09:05:42,031 INFO image_writer.py:197 - writing: /home/skyous/git/prostate158/vision/pred/mixunetr/image_003/image_003_seg.nii.gz


 16%|█▌        | 3/19 [00:15<01:11,  4.49s/it]

2024-03-29 09:05:46,230 INFO image_writer.py:197 - writing: /home/skyous/git/prostate158/vision/pred/mixunetr/image_004/image_004_seg.nii.gz


 21%|██        | 4/19 [00:19<01:05,  4.38s/it]

2024-03-29 09:05:50,595 INFO image_writer.py:197 - writing: /home/skyous/git/prostate158/vision/pred/mixunetr/image_005/image_005_seg.nii.gz


 26%|██▋       | 5/19 [00:23<01:00,  4.35s/it]

2024-03-29 09:05:55,047 INFO image_writer.py:197 - writing: /home/skyous/git/prostate158/vision/pred/mixunetr/image_006/image_006_seg.nii.gz


 32%|███▏      | 6/19 [00:28<00:57,  4.40s/it]

2024-03-29 09:05:59,927 INFO image_writer.py:197 - writing: /home/skyous/git/prostate158/vision/pred/mixunetr/image_007/image_007_seg.nii.gz


 37%|███▋      | 7/19 [00:33<00:55,  4.61s/it]

2024-03-29 09:06:04,332 INFO image_writer.py:197 - writing: /home/skyous/git/prostate158/vision/pred/mixunetr/image_008/image_008_seg.nii.gz


 42%|████▏     | 8/19 [00:37<00:49,  4.51s/it]

2024-03-29 09:06:08,738 INFO image_writer.py:197 - writing: /home/skyous/git/prostate158/vision/pred/mixunetr/image_009/image_009_seg.nii.gz


 47%|████▋     | 9/19 [00:41<00:44,  4.45s/it]

2024-03-29 09:06:13,320 INFO image_writer.py:197 - writing: /home/skyous/git/prostate158/vision/pred/mixunetr/image_010/image_010_seg.nii.gz


 53%|█████▎    | 10/19 [00:46<00:40,  4.52s/it]

2024-03-29 09:06:18,045 INFO image_writer.py:197 - writing: /home/skyous/git/prostate158/vision/pred/mixunetr/image_011/image_011_seg.nii.gz


 58%|█████▊    | 11/19 [00:51<00:36,  4.54s/it]

2024-03-29 09:06:21,139 INFO image_writer.py:197 - writing: /home/skyous/git/prostate158/vision/pred/mixunetr/image_012/image_012_seg.nii.gz


 63%|██████▎   | 12/19 [00:54<00:28,  4.13s/it]

2024-03-29 09:06:24,208 INFO image_writer.py:197 - writing: /home/skyous/git/prostate158/vision/pred/mixunetr/image_013/image_013_seg.nii.gz


 68%|██████▊   | 13/19 [00:57<00:22,  3.80s/it]

2024-03-29 09:06:28,787 INFO image_writer.py:197 - writing: /home/skyous/git/prostate158/vision/pred/mixunetr/image_014/image_014_seg.nii.gz


 74%|███████▎  | 14/19 [01:02<00:20,  4.08s/it]

2024-03-29 09:06:31,977 INFO image_writer.py:197 - writing: /home/skyous/git/prostate158/vision/pred/mixunetr/image_015/image_015_seg.nii.gz


 79%|███████▉  | 15/19 [01:05<00:15,  3.77s/it]

2024-03-29 09:06:36,762 INFO image_writer.py:197 - writing: /home/skyous/git/prostate158/vision/pred/mixunetr/image_016/image_016_seg.nii.gz


 84%|████████▍ | 16/19 [01:10<00:12,  4.09s/it]

2024-03-29 09:06:41,189 INFO image_writer.py:197 - writing: /home/skyous/git/prostate158/vision/pred/mixunetr/image_017/image_017_seg.nii.gz


 89%|████████▉ | 17/19 [01:14<00:08,  4.19s/it]

2024-03-29 09:06:45,838 INFO image_writer.py:197 - writing: /home/skyous/git/prostate158/vision/pred/mixunetr/image_018/image_018_seg.nii.gz


 95%|█████████▍| 18/19 [01:19<00:04,  4.32s/it]

2024-03-29 09:06:50,245 INFO image_writer.py:197 - writing: /home/skyous/git/prostate158/vision/pred/mixunetr/image_019/image_019_seg.nii.gz


100%|██████████| 19/19 [01:23<00:00,  4.40s/it]


In [4]:
import shutil
import os
for root, dirs, files in os.walk('/home/data1/skyous/prostate158/test'):
    for file in files:
        if file.endswith("t2.nii.gz"):
            file_path = os.path.join(root, file)
            dst_path = os.path.join('/home/skyous/git/prostate158/vision/img', 'image_'+file_path.split('/')[-2]+".nii.gz")
            print(dst_path)
            shutil.copy(file_path, dst_path)
            # file_list.append(file_path)



/home/skyous/git/prostate158/vision/img/image_005.nii.gz
/home/skyous/git/prostate158/vision/img/image_009.nii.gz
/home/skyous/git/prostate158/vision/img/image_012.nii.gz
/home/skyous/git/prostate158/vision/img/image_008.nii.gz
/home/skyous/git/prostate158/vision/img/image_014.nii.gz
/home/skyous/git/prostate158/vision/img/image_016.nii.gz
/home/skyous/git/prostate158/vision/img/image_015.nii.gz
/home/skyous/git/prostate158/vision/img/image_007.nii.gz
/home/skyous/git/prostate158/vision/img/image_019.nii.gz
/home/skyous/git/prostate158/vision/img/image_013.nii.gz
/home/skyous/git/prostate158/vision/img/image_011.nii.gz
/home/skyous/git/prostate158/vision/img/image_017.nii.gz
/home/skyous/git/prostate158/vision/img/image_006.nii.gz
/home/skyous/git/prostate158/vision/img/image_010.nii.gz
/home/skyous/git/prostate158/vision/img/image_003.nii.gz
/home/skyous/git/prostate158/vision/img/image_001.nii.gz
/home/skyous/git/prostate158/vision/img/image_002.nii.gz
/home/skyous/git/prostate158/vi

In [5]:
import shutil
import os
for root, dirs, files in os.walk('/home/data1/skyous/prostate158/test'):
    for file in files:
        if file.endswith("t2_anatomy_reader1.nii.gz"):
            file_path = os.path.join(root, file)
            dst_path = os.path.join('/home/skyous/git/prostate158/vision/label', 'label_'+file_path.split('/')[-2]+".nii.gz")
            print(dst_path)
            shutil.copy(file_path, dst_path)

/home/skyous/git/prostate158/vision/label/label_005.nii.gz
/home/skyous/git/prostate158/vision/label/label_009.nii.gz
/home/skyous/git/prostate158/vision/label/label_012.nii.gz
/home/skyous/git/prostate158/vision/label/label_008.nii.gz
/home/skyous/git/prostate158/vision/label/label_014.nii.gz
/home/skyous/git/prostate158/vision/label/label_016.nii.gz
/home/skyous/git/prostate158/vision/label/label_015.nii.gz
/home/skyous/git/prostate158/vision/label/label_007.nii.gz
/home/skyous/git/prostate158/vision/label/label_019.nii.gz
/home/skyous/git/prostate158/vision/label/label_013.nii.gz
/home/skyous/git/prostate158/vision/label/label_011.nii.gz
/home/skyous/git/prostate158/vision/label/label_017.nii.gz
/home/skyous/git/prostate158/vision/label/label_006.nii.gz
/home/skyous/git/prostate158/vision/label/label_010.nii.gz
/home/skyous/git/prostate158/vision/label/label_003.nii.gz
/home/skyous/git/prostate158/vision/label/label_001.nii.gz
/home/skyous/git/prostate158/vision/label/label_002.nii.

# no use

In [23]:
import SimpleITK as sitk

In [35]:


# 读取图像和标签
image = sitk.ReadImage("/home/data1/skyous/SouthHP_prostate/sort_crop/images/image_113.nii.gz")
label = sitk.ReadImage("/home/data1/skyous/SouthHP_prostate/sort_crop/labels/label_113.nii.gz")
print("Image Direction:", image.GetDirection())
print("Label Direction:", label.GetDirection())

# 调整图像和标签的方向
# flipped_image = sitk.Flip(image, [True, True, False])  # 在x和y轴上进行翻转
flipped_label = sitk.Flip(label, [True, False, True])  # 在x和y轴上进行翻转

# 可视化调整后的图像和标签
# sitk.Show(flipped_image, title="Flipped Image")
# sitk.Show(flipped_label, title="Flipped Label")
sitk.WriteImage(flipped_label, "./predict/debug/label_113.nii.gz")


Image Direction: (0.9997967872407916, -0.018594690540259643, 0.007785999245346475, 0.02014630672171616, 0.9079458553698535, -0.41860296587128026, 0.0007145267425893094, 0.4186747678095934, 0.9081359673416306)
Label Direction: (0.9997968500287552, -0.018591595841661694, 0.007785319692728036, 0.020143214590042282, 0.9079459360713122, -0.41860294294653744, 0.0007138477137779891, 0.4186747302329682, 0.9081359837346937)


In [26]:
sitk.WriteImage(flipped_label, "./predict/debug/label_113.nii.gz")


In [46]:

import shutil
order_list = [6,8,15,22,23,33,72,103,122,113,35,47,54,55,60,66]



for num in order_list:
    order_label_name = "label_" + str(num).zfill(3) + ".nii.gz"
    order_image_name = "image_" + str(num).zfill(3) + ".nii.gz"
    label_path = os.path.join("/home/data1/skyous/SouthHP_prostate/sort_crop/labels", order_label_name)
    image_path = os.path.join("/home/data1/skyous/SouthHP_prostate/sort_crop/images", order_image_name)

    label = sitk.ReadImage(label_path)
    flipped_label = sitk.Flip(label, [True, False, True])

    os.makedirs("./predict/debug/newlabel/", exist_ok=True)
    os.makedirs("./predict/debug/newimage/", exist_ok=True)

    sitk.WriteImage(flipped_label, os.path.join("./predict/debug/newlabel/", order_label_name))

    new_image_path = os.path.join("./predict/debug/newimage/", order_image_name)
    shutil.copy(image_path, new_image_path)
    

    

# glob(os.path.join(tempdir, "*.nii.gz"))

In [49]:
import os
order_list = [6,8,15,22,23,33,72,103,122,113,35,47,54,55,60,66]

order_list = [str(num).zfill(3) for num in order_list]
order_list_str = ', '.join(order_list)
print(order_list_str)

# Define the directory path
directory = "/home/data1/skyous/SouthHP_prostate/sort_crop_v2/labels"

# Iterate over the files in the directory
for filename in os.listdir(directory):
    # Check if the filename contains any of the strings in order_list_str
    if any(substr in filename for substr in order_list_str.split(", ")):
        # Delete the file
        os.remove(os.path.join(directory, filename))



006, 008, 015, 022, 023, 033, 072, 103, 122, 113, 035, 047, 054, 055, 060, 066


In [1]:
import pandas as pd
order_list = [6,8,15,22,23,33,72,103,122,113,35,47,54,55,60,66]
# Read the CSV file
df = pd.read_csv("southHP_predict/all_v3.csv")

# Create a list to hold the rows to be deleted
rows_to_delete = []

# Iterate over each row
for index, row in df.iterrows():
    # Check if the ID exists in the order_list
    if row['ID'] in order_list:
        print(f"Deleting row with ID: {row['ID']}", index)
        # Add the index to the list
        rows_to_delete.append(index)

# Delete the rows
df.drop(rows_to_delete, inplace=True)

# Save the modified dataframe to a new CSV file
df.to_csv("southHP_predict/all_v4.csv", index=False)

Deleting row with ID: 6 6
Deleting row with ID: 8 8
Deleting row with ID: 15 15
Deleting row with ID: 22 22
Deleting row with ID: 23 23
Deleting row with ID: 33 33
Deleting row with ID: 35 35
Deleting row with ID: 47 47
Deleting row with ID: 54 54
Deleting row with ID: 55 55
Deleting row with ID: 60 60
Deleting row with ID: 66 66
Deleting row with ID: 72 72
Deleting row with ID: 103 103
Deleting row with ID: 113 113
Deleting row with ID: 122 122
