-
Notifications
You must be signed in to change notification settings - Fork 2
/
cathetergen.py
120 lines (94 loc) · 4.1 KB
/
cathetergen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5"
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from utils import cartesian_product
import skeleton2
import utils
import visdom
def transform(img):
    """Convert a NumPy image to a float32 tensor divided by its maximum value.

    The array is cast to ``int32`` before conversion, so any fractional part is
    truncated; the divisor is the maximum of the *original* (uncast) array.

    :param img: NumPy array of any shape
    :return: float32 tensor with the same shape as ``img``. When the maximum is
        0 the tensor is returned unscaled (dividing would produce inf/NaN), and
        an empty array yields an empty tensor instead of raising.
    """
    img_tensor = torch.from_numpy(img.astype(np.int32)).float()
    if img.size == 0:
        # ndarray.max() raises on empty input; there is nothing to scale.
        return img_tensor
    # Python float keeps the division in the tensor's own float32 dtype.
    max_value = float(img.max())
    if max_value == 0:
        # Avoid 0/0 -> NaN (and x/0 -> inf) for blank images.
        return img_tensor
    return img_tensor / max_value
class SegData(Dataset):
    """Render and save DRRs of a randomly sampled catheter path for each case.

    On construction, for every entry of ``root`` (each expected to be a
    directory containing ``numpy_RG_npy.npy``) this loads the volume, picks a
    random path along the skeleton computed by ``skeleton2.mapping``,
    rasterises that path into a binary volume and saves one DRR per pose in
    ``self.label`` as a ``.npy`` file inside the case directory.

    NOTE(review): ``__getitem__`` is currently a stub that returns 0, so the
    class is effectively used for its side effect of writing DRR files.
    """

    def __init__(self, root, train, transform):
        """
        :param root: the path of data; every entry of this directory is
            treated as one case directory
        :param train: selects which CSV file is (re)created:
            ``train_zzz.csv`` when true, ``test_zzz.csv`` otherwise
        :param transform: transforms to make the output tensor; stored as
            ``self.transform`` but not used by the code in this class
        :raises ValueError: if a case's skeleton has no voxels, so catheter
            endpoints cannot be sampled
        """
        self.root = root
        self.s = 1
        self.dlist = [os.path.join(self.root, x) for x in os.listdir(root)]
        self.transform = transform
        self.zeros = np.array([0], dtype=np.float64).reshape(-1)
        # 15 evenly spaced angles from -30 to 30 inclusive, repeated self.s times.
        self.rotation = np.array([np.linspace(-30, 30, 15)] * self.s).reshape(-1)
        # Translation grid; not currently part of the pose grid below.
        self.translation = np.array([np.linspace(-5, 5, 3)] * self.s).reshape(-1)
        # Pose grid from six 1-D component arrays; only the third one
        # (self.rotation) has more than one value. Each row is used as a
        # 6-vector T when rendering.
        self.label = cartesian_product(self.zeros, self.zeros, self.rotation, self.zeros, self.zeros, self.zeros)
        self.CT = []
        csv_name = 'train_zzz.csv' if train else 'test_zzz.csv'
        # The CSV is created/truncated but nothing is written to it at present
        # (the per-DRR write is disabled). `with` guarantees the handle is
        # closed even if loading or DRR generation raises.
        with open(csv_name, 'w'):
            for f in self.dlist:
                CT = np.load(os.path.join(f, 'numpy_RG_npy.npy'))
                C = self._catheter_volume(CT, f)
                self._save_drrs(f, C)

    @staticmethod
    def _catheter_volume(CT, case_dir):
        """Sample a random non-empty skeleton path through ``CT`` as a tensor.

        Returns a float32 tensor that is 1 on the path voxels and 0 elsewhere,
        reshaped via ``expand_dims(..., -1).transpose((3, 2, 1, 0))``, i.e. a
        leading singleton axis followed by the volume axes in reverse order.
        """
        catheter = []
        while len(catheter) == 0:
            # NOTE(review): mapping(CT) is recomputed on every retry; it could
            # be hoisted if get_road() is known not to mutate the mapping.
            a = skeleton2.mapping(CT)
            skel = a.skel
            xyz = np.where(skel == 1)
            if len(xyz[0]) == 0:
                raise ValueError(
                    "no skeleton voxels found for {}; cannot sample catheter endpoints".format(case_dir))
            # Two random skeleton voxels (presumably start/finish points) —
            # retry until get_road returns a non-empty path between them.
            idx = np.random.randint(len(xyz[0]), size=2)
            sp = np.array([xyz[0][idx[0]], xyz[1][idx[0]], xyz[2][idx[0]]])
            fp = np.array([xyz[0][idx[1]], xyz[1][idx[1]], xyz[2][idx[1]]])
            catheter = a.get_road(sp, fp)
        catheter = np.array(catheter)
        C = np.zeros_like(CT)
        C[catheter[:, 0], catheter[:, 1], catheter[:, 2]] = 1
        C = np.expand_dims(np.array(C, dtype=np.float32), axis=-1).transpose((3, 2, 1, 0))
        return torch.tensor(C)

    def _save_drrs(self, case_dir, C):
        """Render one DRR of ``C`` per pose in ``self.label`` into ``case_dir``.

        Each file is named after the six pose values joined by underscores;
        ``np.save`` appends the ``.npy`` suffix.
        """
        for T in self.label:
            # NOTE(review): [256, 256] is presumably the DRR image size — confirm
            # against utils.DRR_generation.
            drr = utils.DRR_generation(C, torch.tensor(T, dtype=torch.float32).view(1, 6), 1, [256, 256])
            drr_path = os.path.join(case_dir, "{}_{}_{}_{}_{}_{}".format(str(T[0]), str(T[1]), str(T[2]), str(T[3]), str(T[4]), str(T[5])))
            np.save(drr_path, drr.cpu().numpy())

    def __getitem__(self, index):
        """
        NOTE(review): currently a stub that always returns 0. The intended
        contract, per the original author, is:

        :param index:
        :return: CT_out: [C, D, H, W] == [1, 393, 512, 512]
        :return drr: [C, H, W]
        :return T : [6]
        """
        return 0

    def __len__(self):
        """Return the number of poses, i.e. rows of ``self.label``.

        ``self.label.size`` would count all 6 components of every pose
        (and is a method, not a count, if ``label`` is a torch tensor).
        """
        return len(self.label)
if __name__ == "__main__":
    # Constructing each SegData renders and saves DRRs for every case directory
    # under its root; the resulting objects are not otherwise used here.
    train_path = '/home/srk1995/pub/db/Unet_1024/Train/'
    test_path = '/home/srk1995/pub/db/Unet_1024/Test/'
    to_tensor = transforms.ToTensor
    kdata_train = SegData(train_path, train=True, transform=to_tensor())
    kdata_test = SegData(test_path, train=False, transform=to_tensor())
    print("EOP")