/
resample_2022.py
130 lines (111 loc) · 4.88 KB
/
resample_2022.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from pathlib import Path
from multiprocessing import Pool
import logging
import click
import pandas as pd
import numpy as np
import SimpleITK as sitk
# Default folders used by the command line interface when the corresponding
# positional argument is omitted.
path_input_images = "data/imagesTr"
path_input_labels = "data/labelsTr"
# The resampled images get their own folder; this default previously pointed
# at the resampled *labels* folder, mixing both kinds of output together.
path_output_images = "data/imagesTr_resampled"
path_output_labels = "data/labelsTr_resampled"
@click.command()
@click.argument('input_image_folder',
                type=click.Path(exists=True),
                default=path_input_images)
@click.argument('input_label_folder',
                type=click.Path(exists=True),
                default=path_input_labels)
@click.argument('output_image_folder', default=path_output_images)
@click.argument('output_label_folder', default=path_output_labels)
@click.option('--cores',
              type=click.INT,
              default=12,
              help='The number of workers for parallelization.')
@click.option('--resampling',
              type=click.FLOAT,
              nargs=3,
              default=(2, 2, 2),
              help='Expect 3 positive floats describing the output '
              'resolution of the resampling in mm/px along the x, y and z '
              'axes, e.g. --resampling 1.0 1.0 1.0.')
def main(input_image_folder, input_label_folder, output_image_folder,
         output_label_folder, cores, resampling):
    """ This command line interface allows to resample NIFTI files within
        the maximal bounding box covered by the field of view of both
        modalities (PT and CT). The images are resampled with spline
        interpolation of degree 3 and the segmentations are resampled
        by nearest neighbor interpolation.

        INPUT_IMAGE_FOLDER is the path of the folder containing PT and CT images.

        INPUT_LABEL_FOLDER is the path of the folder containing the labels.

        OUTPUT_IMAGE_FOLDER is the path of the folder where to store the resampled PT and CT images.

        OUTPUT_LABEL_FOLDER is the path of the folder where to store the resampled labels.
    """
    # NOTE: `cores` is accepted for interface compatibility but is not used:
    # patients are processed sequentially (see the loop at the end).
    logger = logging.getLogger(__name__)
    logger.info('Resampling')

    # Reject non-positive (or NaN) spacings up front. They would otherwise
    # produce a negative output size and an obscure failure inside SimpleITK.
    spacing = np.asarray(resampling, dtype=float)
    if not np.all(spacing > 0):
        raise click.BadParameter(
            'expected 3 strictly positive values, got {}'.format(
                str(resampling)),
            param_hint='--resampling')

    input_image_folder = Path(input_image_folder).resolve()
    input_label_folder = Path(input_label_folder).resolve()
    output_image_folder = Path(output_image_folder).resolve()
    output_label_folder = Path(output_label_folder).resolve()

    # parents=True so that nested output folders can be created in one go.
    output_image_folder.mkdir(parents=True, exist_ok=True)
    output_label_folder.mkdir(parents=True, exist_ok=True)
    print('resampling is {}'.format(str(resampling)))

    # The patient identifier is the part of the file name before "__".
    # A set removes duplicates (several files may match "*_CT*" for the same
    # patient) and sorting makes the processing order deterministic.
    patient_list = sorted({
        f.name.split("__")[0] for f in input_image_folder.rglob("*_CT*")
    })
    if not patient_list:
        raise ValueError("No patient found in the input folder")

    # One filter instance is shared by all patients; only the origin, size
    # and interpolator are updated per patient.
    resampler = sitk.ResampleImageFilter()
    resampler.SetOutputDirection([1, 0, 0, 0, 1, 0, 0, 0, 1])
    resampler.SetOutputSpacing(spacing.tolist())

    def find_image(pattern):
        """Return the path (as str) of the first image matching `pattern`."""
        matches = sorted(input_image_folder.rglob(pattern))
        if not matches:
            raise FileNotFoundError(
                "No file matching '{}' found in {}".format(
                    pattern, input_image_folder))
        return str(matches[0])

    def resample_one_patient(p):
        """Resample the CT, PT and labels of patient `p` on a common grid."""
        ct = sitk.ReadImage(find_image(p + "__CT*"))
        pt = sitk.ReadImage(find_image(p + "__PT*"))
        labels = [(sitk.ReadImage(str(f)), f.name)
                  for f in input_label_folder.glob(p + "*")]

        # bb holds the lower corner (first 3 values) and the upper corner
        # (last 3 values) of the region covered by both modalities.
        bb = get_bouding_boxes(ct, pt)
        size = np.round((bb[3:] - bb[:3]) / spacing).astype(int)
        if np.any(size <= 0):
            raise ValueError(
                "The CT and PT fields of view of patient {} do not overlap "
                "enough to be resampled (output size {})".format(
                    p, size.tolist()))

        resampler.SetOutputOrigin(bb[:3].tolist())
        # tolist() yields built-in ints; the original code also converted
        # each NumPy integer explicitly before handing it to SimpleITK.
        resampler.SetSize(size.tolist())

        # Intensity images: cubic B-spline interpolation.
        resampler.SetInterpolator(sitk.sitkBSpline)
        ct = resampler.Execute(ct)
        pt = resampler.Execute(pt)
        sitk.WriteImage(ct, str(output_image_folder / (p + "__CT.nii.gz")))
        sitk.WriteImage(pt, str(output_image_folder / (p + "__PT.nii.gz")))

        # Label maps: nearest neighbor, so that no new label value is created.
        resampler.SetInterpolator(sitk.sitkNearestNeighbor)
        for label, name in labels:
            label = resampler.Execute(label)
            sitk.WriteImage(label, str(output_label_folder / name))

    # Sequential processing. A multiprocessing.Pool variant
    # (`Pool(cores).map(resample_one_patient, patient_list)`) was disabled in
    # the original code -- presumably because a nested function cannot be
    # pickled and sent to worker processes; TODO confirm before re-enabling.
    for p in patient_list:
        resample_one_patient(p)
def get_bouding_boxes(ct, pt):
    """
    Compute the box of physical space shared by the CT and PT images.

    Each image is described by its origin and by origin + size * spacing.
    The result is a single 1-D array holding first the element-wise maximum
    of the two origins, then the element-wise minimum of the two far corners
    (six values for 3-D images).

    Only origin, size and spacing are read, so this is valid only when both
    images have the same direction.
    """
    def corners(image):
        # Near corner is the image origin; far corner adds the axis extents.
        near = np.array(image.GetOrigin())
        extent = np.array(image.GetSize()) * np.array(image.GetSpacing())
        return near, near + extent

    ct_near, ct_far = corners(ct)
    pt_near, pt_far = corners(pt)

    shared_near = np.maximum(ct_near, pt_near)
    shared_far = np.minimum(ct_far, pt_far)
    return np.concatenate([shared_near, shared_far], axis=0)
if __name__ == '__main__':
    # Configure the root logger at INFO level before running the command.
    logging.basicConfig(
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        level=logging.INFO,
    )
    # Redirect messages issued through the warnings module to logging.
    logging.captureWarnings(True)

    main()