-
Notifications
You must be signed in to change notification settings - Fork 46
/
acdc_3d.py
114 lines (73 loc) · 3.45 KB
/
acdc_3d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import numpy as np
import SimpleITK as sitk
from utils import ResampleXYZAxis, ResampleFullImageToRef
import os
import random
import yaml
def ResampleCMRImage(imImage, imLabel, save_path, patient_name, count, target_spacing=(1., 1., 1.)):
    """Resample a CMR image/label pair to ``target_spacing`` and write both to disk.

    Both volumes are first rebuilt from their voxel arrays with the original
    spacing and origin but an identity direction matrix. The image is then
    resampled in two passes (an anisotropic scheme following nnUNet):
    B-spline interpolation in the xy plane while keeping the original z
    spacing, then nearest-neighbor interpolation to reach the target z
    spacing. After each pass the label is resampled onto the new image grid
    with ``ResampleFullImageToRef``.

    Args:
        imImage: SimpleITK image to resample.
        imLabel: SimpleITK label volume; must match ``imImage`` in spacing and size.
        save_path: Output directory; created (including parents) if missing.
        patient_name: Prefix for the output file names.
        count: Integer index appended to the output file names.
        target_spacing: Desired (x, y, z) output spacing.

    Writes ``<save_path>/<patient_name>_<count>.nii.gz`` (image) and
    ``<save_path>/<patient_name>_<count>_gt.nii.gz`` (label).

    Raises:
        ValueError: If the image and label differ in spacing or size.
    """
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if imImage.GetSpacing() != imLabel.GetSpacing():
        raise ValueError('image/label spacing mismatch: %s vs %s'
                         % (imImage.GetSpacing(), imLabel.GetSpacing()))
    if imImage.GetSize() != imLabel.GetSize():
        raise ValueError('image/label size mismatch: %s vs %s'
                         % (imImage.GetSize(), imLabel.GetSize()))

    spacing = imImage.GetSpacing()[0:3]
    origin = imImage.GetOrigin()[0:3]
    identity_direction = (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)

    def _rebuild(itk_image):
        # Copy the voxels into a fresh image that keeps the original spacing
        # and origin but drops the original orientation (identity direction).
        rebuilt = sitk.GetImageFromArray(sitk.GetArrayFromImage(itk_image))
        rebuilt.SetSpacing(spacing)
        rebuilt.SetOrigin(origin)
        rebuilt.SetDirection(identity_direction)
        return rebuilt

    tmp_itkimg = _rebuild(imImage)
    tmp_itklab = _rebuild(imLabel)

    # exist_ok avoids the check-then-create race and also creates parents.
    os.makedirs(save_path, exist_ok=True)

    # Pass 1: B-spline in the xy plane; the z spacing is kept at its original value.
    re_img_xy = ResampleXYZAxis(tmp_itkimg, space=(target_spacing[0], target_spacing[1], spacing[2]),
                                interp=sitk.sitkBSpline)
    # The label's interpolation mode is set inside ResampleFullImageToRef (utils);
    # presumably nearest-neighbor for label maps -- TODO confirm.
    re_lab_xy = ResampleFullImageToRef(tmp_itklab, re_img_xy)
    # Pass 2: nearest-neighbor to reach the target z spacing (xy already at target).
    re_img_xyz = ResampleXYZAxis(re_img_xy, space=(target_spacing[0], target_spacing[1], target_spacing[2]),
                                 interp=sitk.sitkNearestNeighbor)
    re_lab_xyz = ResampleFullImageToRef(re_lab_xy, re_img_xyz)

    sitk.WriteImage(re_img_xyz, os.path.join(save_path, '%s_%d.nii.gz' % (patient_name, count)))
    sitk.WriteImage(re_lab_xyz, os.path.join(save_path, '%s_%d_gt.nii.gz' % (patient_name, count)))
if __name__ == '__main__':
    # Preprocess the ACDC training set: resample every (frame, ground-truth) pair
    # to a fixed spacing and write a YAML list of patient names.
    src_path = '/research/cbim/medical/medical-share/public/ACDC/raw/training/'
    tgt_path = '/research/cbim/medical/yg397/ACDC_3d/'
    target_spacing = (1.5625, 1.5625, 5.0)

    # This is to align the train/val/test split with TransUNet, SwinUNet, etc.
    patient_list = list(range(1, 101))
    # val_list = [89, 90, 91, 93, 94, 96, 97, 98, 99, 100]
    # test_list = [2, 3, 8, 9, 12, 14, 17, 24, 42, 48, 49, 53, 55, 64, 67, 79, 81, 88, 92, 95]
    # train_list = list(set(patient_list) - set(val_list) - set(test_list))

    # If you don't want to align with them, use the following code for a random split:
    # patient_list = list(range(1, 101))
    # random.seed(0)
    # random.shuffle(patient_list)
    # train_list = patient_list[:70]
    # val_list = patient_list[70:80]
    # test_list = patient_list[80:]

    name_list = ['patient%.3d' % idx for idx in patient_list]

    # makedirs(exist_ok=True): reruns don't crash and tgt_path is created if absent.
    list_dir = os.path.join(tgt_path, 'list')
    os.makedirs(list_dir, exist_ok=True)
    with open(os.path.join(list_dir, 'dataset.yaml'), 'w', encoding='utf-8') as f:
        yaml.dump(name_list, f)

    # Path joins instead of os.chdir; sorted() makes both the processing order and
    # the per-patient `count` suffix (which frame becomes _0 vs _1) reproducible.
    for name in sorted(os.listdir(src_path)):
        patient_dir = os.path.join(src_path, name)
        if not os.path.isdir(patient_dir):
            # Skip stray files (e.g. documentation) in the source folder.
            continue
        count = 0
        for fname in sorted(os.listdir(patient_dir)):
            if not fname.endswith('_gt.nii.gz'):
                continue
            tmp = fname.split('_')
            img_name = tmp[0] + '_' + tmp[1]
            patient_name = tmp[0]
            img = sitk.ReadImage(os.path.join(patient_dir, '%s.nii.gz' % img_name))
            lab = sitk.ReadImage(os.path.join(patient_dir, '%s_gt.nii.gz' % img_name))
            ResampleCMRImage(img, lab, tgt_path, patient_name, count, target_spacing)
            count += 1
        print(name, 'done')