In [2]:
import h5py
import SimpleITK as sitk
import os
join = os.path.join
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import json
import nibabel as nib
import json
from tqdm import tqdm

## 从h5到h5

### 第一步：拆分h5文件

In [None]:
store_path = "/data3/home/lishengyong/data/ssc_3d/slices"
h5_path = "h5所在的文件夹"

# Create the output sub-folders used by the splitting step: image volumes,
# mask volumes, and the JSON lists of sample names.
for sub_dir in ("3D_images", "3D_masks", "jsons"):
    target_dir = os.path.join(store_path, sub_dir)
    os.makedirs(target_dir, exist_ok=True)
    
def write(a, path, spacing=None):
    """Save an array as an image file through SimpleITK.

    Args:
        a: Array passed to ``sitk.GetImageFromArray``.
        path: Destination file path passed to ``sitk.WriteImage``.
        spacing: Optional voxel spacing. May be a NumPy array or any plain
            sequence of numbers (list/tuple). When ``None`` the spacing of the
            created image is left untouched.
    """
    image = sitk.GetImageFromArray(a)
    if spacing is not None:
        # np.asarray is a no-op for arrays and lets lists/tuples through as well,
        # so callers are no longer required to pass an object with .tolist().
        # NOTE(review): SimpleITK orders image axes in reverse of the NumPy array
        # axes — confirm the spacing values are stored in the matching order.
        image.SetSpacing(spacing=np.asarray(spacing).tolist())
    sitk.WriteImage(image, path)
    
def section(store_path, h5path):
    """Split one HDF5 file into per-sample ``.nii.gz`` image and mask volumes.

    For every entry of the file's ``centers`` dataset the matching sample of
    ``datas`` and ``masks`` is cropped to ``32:160`` on the two axes following
    the sample axis. Mask values other than 1 are set to 0. Both volumes are
    written as ``<name>.nii.gz`` into the ``3D_images`` / ``3D_masks``
    sub-folders of ``store_path``, using the sample's entry of ``spacings``.
    The list of generated names is dumped to
    ``<store_path>/jsons/<h5 file name with .json suffix>``.

    Args:
        store_path: Root output directory; its ``3D_images``, ``3D_masks`` and
            ``jsons`` sub-folders are expected to exist already.
        h5path: Path of the source HDF5 file (``str`` or ``pathlib.Path``).
    """
    h5path = Path(h5path)
    names = []
    with h5py.File(h5path, "r") as file:
        for j in range(len(file["centers"])):
            name = f"{file['names'][j][0].decode('utf-8')}_{file['centers'][j]}"
            names.append(name)
            spacing = file["spacings"][j]
            data_3D = file["datas"][j, 32:160, 32:160, :]
            Mask_3D = file["masks"][j, 32:160, 32:160, :]
            # Keep only label 1 in the mask
            Mask_3D[Mask_3D != 1] = 0
            # Build both target paths explicitly instead of string-replacing
            # "3D_images", which could also hit the root path or the sample name.
            volume_name = f"{name}.nii.gz"
            write(Mask_3D, join(store_path, "3D_masks", volume_name), spacing)
            write(data_3D, join(store_path, "3D_images", volume_name), spacing)
    # The file must be opened for writing; the default mode is read-only and
    # made json.dump fail.
    json_name = h5path.with_suffix(".json").name
    with open(join(store_path, "jsons", json_name), "w") as json_file:
        json.dump(names, json_file)
            
# Split every .h5 file located directly inside h5_path
for h5_file in Path(h5_path).glob("*.h5"):
    section(store_path=store_path, h5path=h5_file)

### 第二步：推理
可以直接用 infer_3d_multibox.py 得到每个轴的结果  
运行方法：`CUDA_VISIBLE_DEVICES=0 python infer_3d_multibox.py --axis {} --split {} --root {} --task_name {}`  
axis是表示从哪个轴推理，0轴是XY平面  
split 默认为 -1，表示不再拆分数据集，大概会占用 3G 多的显存；否则取 0~4，表示把数据集拆成五份，取值对应其中的第几份。五份同时跑差不多把72GCPU占满  
root就是上面的store_path，task_name就是h5文件名（不包括.h5后缀）  

infer_3d_multibox.py 得到第一轴的结果后，可以运行 iter_infer.py   
`CUDA_VISIBLE_DEVICES=0 python iter_infer.py --min_iter {} --max_iter {} --split {} --root {} --task_name {}`  
从上述1轴的结果产生2轴的结果，表示1个iter，如果是1（已有）--2--0顺序，那min_iter=2,max_iter=4  
如果是1--2--0--1--2--0顺序，那min_iter=2,max_iter=7，确保iter % 3 是对应的轴  
其余跟上面一样  

#### 产生的embadding文件如果之后用不到可以用
find root_path -type d -name "embadding" -exec rm -r {} \;清理

## 第三步：后处理
### open操作(先腐蚀后膨胀去除小像素点)

In [None]:
from utils.utils import opening

# Folder that receives the post-processed results
open_store_path = "保存路径"
os.makedirs(open_store_path, exist_ok=True)

# Folder holding the inference results to post-process. According to the
# section heading, `opening` erodes and then dilates to remove small pixel
# specks — its implementation lives in utils/utils.py.
infer_result_path = "推理结果的路径，比如某个iter文件夹下面的infer文件夹"
opening(infer_result_path, open_store_path)

### 迭代后的结果加回原来的结果

In [None]:
# Names (without extension) of the cases to merge
with open("/data3/home/lishengyong/data/ssc_3d/slices/jsons/lndb_weight.json", "r") as file:
    cases = json.load(file)

# direct_path: predictions of the direct inference
# iter_path:   post-processed predictions of the iterative inference
# store_path:  destination of the merged predictions
direct_path = "/data3/home/lishengyong/data/ssc_3d/slices/infer_3d_iter/axis_0/0818_98/multi_box"
iter_path = "/data3/home/lishengyong/data/ssc_3d/slices/infer_3d_iter/axis_1/0818_98/iter9/opening"
store_path = "/data3/home/lishengyong/data/ssc_3d/slices/infer_3d_iter/axis_0/0818_98/jiehe"
os.makedirs(store_path, exist_ok=True)

for case in cases:
    case = case + ".nii.gz"
    direct_pred = nib.load(join(direct_path, case)).get_fdata()
    # Relabel the direct prediction: label 1 is discarded, label 2 becomes 1
    direct_pred[direct_pred == 1] = 0
    direct_pred[direct_pred == 2] = 1
    # The iterative prediction has to come from iter_path; it used to be read
    # from direct_path again, which left iter_path unused and made the merge
    # return the direct prediction only.
    iter_pred = nib.load(join(iter_path, case)).get_fdata()
    if iter_pred.shape != direct_pred.shape:
        raise ValueError(
            f"Shape mismatch for {case}: direct {direct_pred.shape} vs iter {iter_pred.shape}"
        )
    # Indices along the first axis where the direct prediction is non-empty
    has_mask = [index for index, arr in enumerate(direct_pred) if np.max(arr) != 0]
    # On those slices the direct prediction overrides the iterative one
    for i in has_mask:
        iter_pred[i] = direct_pred[i]
    nib.save(nib.Nifti1Image(iter_pred, None), join(store_path, case))


## 第四步 拼回H5文件

In [2]:
pred_path = "/data3/home/lishengyong/data/ssc_3d/slices/infer_3d_iter/axis_0/0818_98/jiehe"
save_path = "/data3/home/lishengyong/data/ssc_3d/slices/infer_3d_iter/lndb_weigh.h5"

# Case names, in the order in which the predictions get stacked
with open("/data3/home/lishengyong/data/ssc_3d/slices/jsons/lndb_weight.json", "r") as file:
    cases = json.load(file)
print(len(cases))

# Pad as the situation requires: 32 zeros on both sides of the first two axes,
# nothing on the last one (presumably undoing the 32:160 crop of step one —
# adjust if the crop changes).
pad_width = ((32, 32), (32, 32), (0, 0))

preds = []
for case in tqdm(cases, total=len(cases)):
    case = case + ".nii.gz"
    pred = nib.load(join(pred_path, case))
    # Reverse the axis order of the loaded volume
    pred_arr = pred.get_fdata().transpose(2, 1, 0)
    pred_arr = np.pad(pred_arr, pad_width, mode="constant", constant_values=0).astype(int)
    # Report cases whose padded prediction has a maximum of 0
    if pred_arr.max() == 0:
        print(case)
    preds.append(pred_arr)

# One array with a leading case axis, stored as the "masks" dataset
pred_arr = np.stack(preds, axis=0)
print(pred_arr.shape)
with h5py.File(save_path, "w") as file:
    file["masks"] = pred_arr

# Alternative kept for reference: save the stacked array as a NIfTI file instead
# pred_img = nib.Nifti1Image(pred_arr, None)
# nib.save(pred_img, save_path)

1154


100%|██████████| 1154/1154 [00:33<00:00, 34.15it/s]


(1154, 192, 192, 64)


In [4]:
# Load the stacked predictions from the NIfTI volume and convert them to booleans
# (NOTE(review): that .nii.gz is not produced by any active code in this
# notebook — presumably it came from the commented-out NIfTI save; confirm).
a = nib.load("/data3/home/lishengyong/data/ssc_3d/slices/infer_3d_iter/lndb_weigh.nii.gz").get_fdata().astype(bool)
save_path = "/data3/home/lishengyong/data/ssc_3d/slices/infer_3d_iter/lndb_weigh.h5"

# Quick look at one 2-D slice
print(a[1, 1, ...])

# Mode "w" replaces any existing file at save_path
with h5py.File(save_path, "w") as file:
    file["masks"] = a

[[False False False ... False False False]
 [False False False ... False False False]
 [False False False ... False False False]
 ...
 [False False False ... False False False]
 [False False False ... False False False]
 [False False False ... False False False]]
