# dixon_rave_abdomen.py

from math import pi

import numpy as np
import os
import gc
import scipy.io


# Pick the GPU with the most free memory (parsed from nvidia-smi) and expose only that
# device to CUDA. This must run before torch is imported.
def get_freer_gpu():
    os.system('nvidia-smi -q -d Memory |grep -A5 GPU|grep Free >tmp')
    memory_available = [int(x.split()[2]) for x in open('tmp', 'r').readlines()]
    return np.argmax(memory_available)


gpu = get_freer_gpu()
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu)

import torch
import torch.nn as nn
from tqdm import tqdm
import imageio
from PIL import Image
import io
import matplotlib.pyplot as plt
import matplotlib.patches as patches

np.random.seed(0)
torch.manual_seed(0)  # manual_seed gives a reproducible run; torch.random.seed() would re-seed randomly
dtype = torch.float32
device = torch.device("cuda:0")
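

# Illustrative only, not called anywhere in this script: a conventional peak-normalized PSNR,
# shown for contrast with the -10*log10(MSE) proxy reported in train() below. The peak
# argument is an assumption; the data range here is not fixed to [0, 1].
def psnr_db(pred, target, peak=1.0):
    mse = torch.mean(torch.abs(pred - target) ** 2)
    return 10 * torch.log10(peak ** 2 / mse)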


class PropMRIOffResonanceCorrection(nn.Module):
    # Build the log dirs, load the Dixon-RAVE fat/water images, initialize the 3D model
    # volume, and precompute the blurred training projections and the overhead test image.
    def __init__(self, dim_x, dim_y, dim_z, initial_val=0.01, blade_width=None, num_views=5, expname='', organ='liver'):
        super().__init__()
        ### Make the log dirs
        self.log_train = os.path.join(expname, 'train')
        self.log_test = os.path.join(expname, 'test')
        os.makedirs(self.log_train, exist_ok=True)
        os.makedirs(self.log_test, exist_ok=True)
        ### Dimensions
        self.dim_x, self.dim_y, self.dim_z = dim_x, dim_y, dim_z
        ### Create image for training
        # Import complex-valued fat and water images from Dixon-RAVE
        if organ == 'liver':
            fat = scipy.io.loadmat('/home/datasets/dixon_rave_mri/liver/fat.mat')['fat']
            water = scipy.io.loadmat('/home/datasets/dixon_rave_mri/liver/water.mat')['water']
        else:
            fat = scipy.io.loadmat('/home/datasets/dixon_rave_mri/breast/fat.mat')['fat']
            water = scipy.io.loadmat('/home/datasets/dixon_rave_mri/breast/water.mat')['water']
        self.image_fat_layer = torch.from_numpy(fat).to(device)
        self.image_water_layer = torch.from_numpy(water).to(device)
        vis = torch.abs(self.image_fat_layer.detach() + self.image_water_layer.detach())
        vis = vis - torch.min(vis)
        vis = np.asarray((vis / torch.max(vis) * 255).cpu()).astype(np.uint8)  # normalize to [0, 255] before the uint8 cast
        imageio.imwrite(os.path.join(self.log_train, 'gtimg.png'), vis)
        vis = torch.abs(torch.cat((self.image_water_layer.detach(), self.image_fat_layer.detach()), dim=1)).detach()
        vis = vis - torch.min(vis)
        vis = np.asarray((vis / torch.max(vis) * 255).cpu()).astype(np.uint8)
        imageio.imwrite(os.path.join(self.log_test, 'gt_fat_water.png'), vis)
        ### Initialize the model: one real and one imaginary 3D volume, learned as parameters.
        A = torch.ones(size=(dim_x, dim_y, dim_z), dtype=dtype, device=device)
        B = torch.ones(size=(dim_x, dim_y, dim_z), dtype=dtype, device=device)
        self.real_model = nn.Parameter(data=A * initial_val)
        self.imag_model = nn.Parameter(data=B * initial_val)
        ### Ground-truth volume: water and fat are placed on two smooth (Gaussian-bumped) z-surfaces
        self.true_real = torch.zeros(size=(dim_x, dim_y, dim_z), dtype=dtype, device=device)
        self.true_imag = torch.zeros(size=(dim_x, dim_y, dim_z), dtype=dtype, device=device)

        def make_gaussian(x, y, N):
            alpha = 6 * N
            y_scale = 4
            return torch.exp(-(1 / (2 * y_scale * alpha)) * (x - 7 * N / 8) ** 2 - (1 / (2 * alpha)) * (y - 5 * N / 9) ** 2).to(device)

        xs, ys = torch.meshgrid(torch.arange(dim_x), torch.arange(dim_y), indexing='ij')
        xs = xs.to(device)
        ys = ys.to(device)
        self.true_real[xs, ys, (dim_z / 5 + dim_z / 4 * make_gaussian(xs, ys, dim_x)).long()] = torch.real(self.image_water_layer)
        self.true_real[xs, ys, (3 * dim_z / 4 + dim_z / 4 * make_gaussian(xs, ys, dim_x)).long()] = torch.real(self.image_fat_layer)
        self.true_imag[xs, ys, (dim_z / 5 + dim_z / 4 * make_gaussian(xs, ys, dim_x)).long()] = torch.imag(self.image_water_layer)
        self.true_imag[xs, ys, (3 * dim_z / 4 + dim_z / 4 * make_gaussian(xs, ys, dim_x)).long()] = torch.imag(self.image_fat_layer)
        ### Generate training and testing angles
        all_angles = np.arange(0, pi + 0.01, pi / num_views)[1:]  # angle 0 is buggy
        print(f'training with {len(all_angles)} views')
        np.random.shuffle(all_angles)
        self.training_angles = all_angles
        self.test_angle = pi
        ### Generate training and testing data: ground-truth projections rendered at initialization
        self.training_data = []
        self.training_kspace = []
        for theta in self.training_angles:
            img_ren = self.render_img(theta, use_overhead=False, use_predicted=False)
            # Save copies of the raw projections for later use
            np.save(os.path.join(self.log_train, f'pre_blur_train_{theta}.npy'), np.array(img_ren.cpu()))
            gt_kspace, gt_img = self.blur_img(img=img_ren, theta=theta)
            self.training_data.append(gt_img)
            self.training_kspace.append(gt_kspace)
            # Save copies of the blade masks for visualization
            mask = self.make_mask(theta, img_ren.shape[0])
            vis = mask.detach().cpu()
            np.save(os.path.join(self.log_train, f'mask_{theta}.npy'), np.array(mask.cpu()))
            vis = np.asarray(vis * 255).astype(np.uint8)
            imageio.imwrite(os.path.join(self.log_train, f'mask_{theta}.png'), vis)
        self.testing_data = self.render_img(self.test_angle, use_overhead=True, use_predicted=False)
        vis = torch.abs(self.testing_data.detach())
        vis = vis - torch.min(vis)
        vis = np.asarray((vis / torch.max(vis) * 255).cpu()).astype(np.uint8)
        imageio.imwrite(os.path.join(self.log_test, 'overheadtruth.png'), vis)
    # Render one PROPELLER blade mask: a rectangle of full width N and height
    # N*sqrt((3-sqrt(5))/(5+sqrt(5))) = N*tan(pi/10), rotated by theta about the k-space center.
    # Per the note in __main__, this is the blade width for 5 blades that exactly cover the space.
    def make_mask(self, theta, N):
        im = Image.fromarray(np.uint8(np.zeros((N, N))))
        fig, ax = plt.subplots()
        _ = ax.imshow(im)
        width = N
        height = N * np.sqrt((3 - np.sqrt(5.)) / (5 + np.sqrt(5.)))
        x = 0
        y = N / 2 - height / 2
        rect = patches.Rectangle((x, y), width=width, height=height, linewidth=0, angle=theta * 180 / np.pi, rotation_point='center', edgecolor='white', facecolor='white')
        ax.add_patch(rect)
        io_buf = io.BytesIO()
        fig.savefig(io_buf, format='raw')
        io_buf.seek(0)
        data = np.reshape(np.frombuffer(io_buf.getvalue(), dtype=np.uint8),
                          newshape=(int(fig.bbox.bounds[3]), int(fig.bbox.bounds[2]), -1))
        io_buf.close()
        plt.close()
        # Crop the rendered axes image out of the figure border
        ycrop = np.where(data[data.shape[0] // 2, :, 0] == 0)[0]  # y min
        xcrop = np.where(data[:, data.shape[1] // 2, 0] == 0)[0]  # x min
        cropped = data[xcrop[0]:xcrop[1], ycrop[0]:ycrop[1]].copy()
        # Make it black and white
        cropped[cropped < 200] = 0
        # Resize back to N by N
        im = Image.fromarray(cropped)
        im = im.resize((N, N))
        cropped = np.array(im)
        # Change data type
        cropped = np.array(cropped[:, :, 0], dtype=np.float32) / 255
        mask = torch.from_numpy(cropped).to(device)
        return mask
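
    # A minimal alternative sketch (not called anywhere in this script): build the same
    # rotated-blade mask analytically by rotating pixel coordinates, instead of rasterizing a
    # matplotlib Rectangle and cropping/resizing the figure. Edge anti-aliasing will differ
    # slightly from the figure-based make_mask above.
    def make_mask_analytic(self, theta, N):
        height = N * np.sqrt((3 - np.sqrt(5.)) / (5 + np.sqrt(5.)))
        ys, xs = np.meshgrid(np.arange(N) - N / 2, np.arange(N) - N / 2, indexing='ij')
        # Rotate coordinates by -theta so the blade is axis-aligned in the rotated frame
        xr = xs * np.cos(theta) + ys * np.sin(theta)
        yr = -xs * np.sin(theta) + ys * np.cos(theta)
        blade = (np.abs(yr) <= height / 2) & (np.abs(xr) <= N / 2)
        return torch.from_numpy(blade.astype(np.float32)).to(device)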
    # Blur an image according to a PROPELLER blade k-space mask.
    # Expects a complex input and returns (masked k-space, complex blurred image).
    def blur_img(self, img, theta):
        assert img.dtype == torch.complex64
        mask = self.make_mask(theta, img.shape[0])
        img_FFT = torch.fft.fft2(img)
        img_DFT = torch.fft.fftshift(img_FFT)
        img_DFT_masked = img_DFT * mask
        inv_img_DFT_masked = torch.fft.ifftshift(img_DFT_masked)
        inv_img_DFT_masked = torch.fft.ifft2(inv_img_DFT_masked)
        return img_DFT_masked, inv_img_DFT_masked
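
    # Illustrative helper (not used during training): the union of all blade masks over
    # self.training_angles, handy for checking how much of k-space the chosen views cover.
    def coverage_mask(self, N):
        cov = torch.zeros((N, N), device=device)
        for theta in self.training_angles:
            cov = torch.maximum(cov, self.make_mask(theta, N))
        return cov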
    # Compute a projection of the volume at the measurement angle. The result is a 2D image.
    # Uses both the real and imaginary model volumes and returns a complex image.
    def render_img(self, measurement_angle, use_overhead=False, use_predicted=True, which_layer='all'):
        xs, ys = torch.meshgrid(torch.arange(self.real_model.shape[0]), torch.arange(self.real_model.shape[1]), indexing='ij')
        image_real = self.render_image(xs, ys, measurement_angle, use_overhead, use_predicted, use_real=True, which_layer=which_layer)
        image_imag = self.render_image(xs, ys, measurement_angle, use_overhead, use_predicted, use_real=False, which_layer=which_layer)
        # Combine into a single complex image
        image = torch.complex(image_real, image_imag).to(device)
        assert image.dtype == torch.complex64
        return image
    def render_image(self, pixel_x, pixel_y, measurement_angle, use_overhead, use_predicted, use_real, which_layer='all'):
        ### pixel_x and pixel_y are 2D tensors with shape [W, H] holding the pixel indices to render
        if use_predicted:
            if use_real:
                grid = self.real_model
            else:
                grid = self.imag_model
        else:
            if use_real:
                grid = self.true_real
            else:
                grid = self.true_imag
        ### Get the ray/volume intersection points (x, y, z)
        xs, ys, zs, n_pts = self.get_intersection_pts(pixel_x, pixel_y, measurement_angle, use_overhead)  # xs, ys, zs each have shape [W, H, npts]
        ### Normalize to [-1, 1] for grid_sample, then reshape and permute
        normalized_xs = (2 * xs) / self.dim_x - 1
        normalized_ys = (2 * ys) / self.dim_y - 1
        normalized_zs = (2 * zs) / self.dim_z - 1
        xs = normalized_xs[None, None, ...]
        ys = normalized_ys[None, None, ...]
        zs = normalized_zs[None, None, ...]
        ### Render just fat, just water, or everything together
        ### (fat sits at normalized z >= 0, water at normalized z < 0; see the z-surfaces built in __init__)
        if which_layer == 'fat':
            keep_idx = (zs >= 0)
            xs = xs[keep_idx].reshape(xs.shape[:-1] + (-1,))
            ys = ys[keep_idx].reshape(ys.shape[:-1] + (-1,))
            zs = zs[keep_idx].reshape(zs.shape[:-1] + (-1,))
            n_pts = zs.shape[-1]
        elif which_layer == 'water':
            keep_idx = (zs < 0)
            xs = xs[keep_idx].reshape(xs.shape[:-1] + (-1,))
            ys = ys[keep_idx].reshape(ys.shape[:-1] + (-1,))
            zs = zs[keep_idx].reshape(zs.shape[:-1] + (-1,))
            n_pts = zs.shape[-1]
        else:
            assert which_layer == 'all'
        xs_perm = torch.permute(xs, dims=(4, 1, 2, 3, 0))
        ys_perm = torch.permute(ys, dims=(4, 1, 2, 3, 0))
        zs_perm = torch.permute(zs, dims=(4, 1, 2, 3, 0))
        ### Trilinearly sample the volume at the intersection points and sum along each ray
        input_grid = grid[None, None, ...]
        input_grid = input_grid.repeat(n_pts, 1, 1, 1, 1)
        grid_input = torch.cat((zs_perm, ys_perm, xs_perm), -1).cuda()
        points_grid_samp = torch.nn.functional.grid_sample(input_grid, grid_input, mode='bilinear', padding_mode='zeros', align_corners=False)
        points = torch.permute(points_grid_samp, dims=(3, 4, 0, 1, 2)).squeeze()
        pixel_value = torch.sum(points, dim=-1)
        return pixel_value
    # Return the points where each pixel's ray intersects the volume, given the measurement angle.
    # Rays start at (pixel_x, pixel_y, 0) and advance by step_size * (cos(theta)sin(phi), sin(theta)sin(phi), cos(phi)),
    # where phi is the offset angle below; the overhead view uses phi ~= 0, i.e. nearly vertical rays.
    def get_intersection_pts(self, pixel_x, pixel_y, measurement_angle, use_overhead):
        ### pixel_x and pixel_y have shape [W, H]
        if use_overhead:
            offset_angle = 0.001 / 180 * pi
        else:
            offset_angle = 20 / 180 * pi  # design parameter, loosely tied to the z (omega) resolution
        step_size = 0.5
        n_pts = int(self.real_model.shape[2] * 2 / step_size)  # upper bound on samples per ray
        x_step = step_size * np.cos(measurement_angle) * np.sin(offset_angle)
        y_step = step_size * np.sin(measurement_angle) * np.sin(offset_angle)
        z_step = step_size * np.cos(offset_angle)
        # NOTE: the zero-step branches below pin the coordinate to 0 rather than to the pixel index;
        # they only trigger when a step is exactly zero (e.g. angle 0), which is why angle 0 is
        # skipped when the training views are generated in __init__.
        if (np.cos(measurement_angle) * np.sin(offset_angle) == 0):
            xs = torch.zeros(n_pts)[None, None, 0:n_pts] + torch.zeros(pixel_x.shape[0], pixel_x.shape[1], 1)
        else:
            xs = torch.arange(start=0, end=0 + n_pts * x_step, step=x_step)[None, None, 0:n_pts] + pixel_x[:, :, None]
        if (np.sin(measurement_angle) * np.sin(offset_angle) == 0):
            ys = torch.zeros(n_pts)[None, None, 0:n_pts] + torch.zeros(pixel_x.shape[0], pixel_x.shape[1], 1)
        else:
            ys = torch.arange(start=0, end=0 + n_pts * y_step, step=y_step)[None, None, 0:n_pts] + pixel_y[:, :, None]
        if (np.cos(offset_angle) == 0):
            zs = torch.zeros(n_pts)[None, None, 0:n_pts] + torch.zeros(pixel_x.shape[0], pixel_x.shape[1], 1)
        else:
            zs = torch.arange(start=0, end=0 + n_pts * z_step, step=z_step)[None, None, 0:n_pts] + torch.zeros(pixel_x.shape[0], pixel_x.shape[1], 1)
        ### xs, ys, zs all have shape [W, H, npts]
        return xs, ys, zs, n_pts
    def compute_tv(self):
        # Total variation: shift the model along x, y and z and average the absolute differences
        model = torch.complex(self.real_model, self.imag_model).to(device)
        x_diff = model[1:, :, :] - model[:-1, :, :]
        y_diff = model[:, 1:, :] - model[:, :-1, :]
        z_diff = model[:, :, 1:] - model[:, :, :-1]
        return torch.mean(torch.abs(x_diff)) + torch.mean(torch.abs(y_diff)) + torch.mean(torch.abs(z_diff))

    def compute_sparsity(self):
        # L1 sparsity penalty on the magnitude of the complex model
        model = torch.complex(self.real_model, self.imag_model).to(device)
        return torch.mean(torch.abs(model))
    # Compute the loss between our prediction and the measurement. For training views the
    # prediction is rendered and then blurred through the PROPELLER blade mask; for the overhead
    # test view the un-blurred magnitude is compared directly and fat/water previews are saved.
    def compute_loss(self, measurement_angle, gt_img, kspace_gt=None, use_overhead=False, tv_lambda=0.01, sparsity_lambda=0.001):
        img_predicted = self.render_img(measurement_angle, use_overhead, use_predicted=True)  # rendered
        if use_overhead:
            log_dir = self.log_test
            mse_loss = torch.nn.functional.mse_loss(torch.abs(img_predicted), torch.abs(gt_img))
            # Also save fat-only and water-only predictions
            img_fat = self.render_img(measurement_angle, use_overhead, use_predicted=True, which_layer='fat')
            img_water = self.render_img(measurement_angle, use_overhead, use_predicted=True, which_layer='water')
            vis = torch.abs(torch.cat((img_water, img_fat), dim=1)).detach()
            vis = vis - torch.min(vis)
            vis = np.asarray((vis / torch.max(vis) * 255).cpu()).astype(np.uint8)
            imageio.imwrite(os.path.join(log_dir, 'fat_water.png'), vis)
        else:
            assert kspace_gt is not None
            kspace_predicted, img_predicted = self.blur_img(img_predicted, measurement_angle)  # rendered and blurred
            log_dir = self.log_train
            mse_loss = 0.000001 * torch.mean(torch.abs(torch.square(kspace_predicted - kspace_gt))) + 0.999999 * torch.mean(torch.abs(torch.square(img_predicted - gt_img)))
        ### Save the predicted image next to the ground truth
        if log_dir is not None:
            vis = torch.abs(torch.cat((gt_img, img_predicted), dim=1)).detach()
            vis = vis - torch.min(vis)
            vis = np.asarray((vis / torch.max(vis) * 255).cpu()).astype(np.uint8)
            imageio.imwrite(os.path.join(log_dir, f'angle_{measurement_angle}.png'), vis)
        assert mse_loss.dtype == torch.float32
        tv_loss = 0
        if tv_lambda > 0:
            tv_loss = tv_lambda * self.compute_tv()
        sparsity_loss = 0
        if sparsity_lambda > 0:
            sparsity_loss = sparsity_lambda * self.compute_sparsity()
        return mse_loss, tv_loss, sparsity_loss
    def save_model(self):
        # Save the ground-truth and reconstructed volumes (real and imaginary parts) as .npy files
        np.save(os.path.join(self.log_train, 'gt_real.npy'), np.array(self.true_real.detach().cpu()))
        np.save(os.path.join(self.log_train, 'recon_real.npy'), np.array(self.real_model.detach().cpu()))
        np.save(os.path.join(self.log_train, 'gt_imag.npy'), np.array(self.true_imag.detach().cpu()))
        np.save(os.path.join(self.log_train, 'recon_imag.npy'), np.array(self.imag_model.detach().cpu()))

    def save_slices(self, slices):
        # Save side-by-side (ground truth | reconstruction) images of selected z slices
        for sl in slices:
            gt = self.true_real[:, :, sl]
            pred = self.real_model[:, :, sl]
            vis = torch.cat((gt, pred), dim=1).detach()
            vis = vis - torch.min(vis)
            vis = np.asarray((vis / torch.max(vis) * 255).cpu()).astype(np.uint8)
            imageio.imwrite(os.path.join(self.log_test, f'slice_{sl}.png'), vis)
    # Run the optimization loop: fit the model volumes to the blurred training views, evaluate on
    # the overhead test view after every epoch, and save the model whenever the overhead PSNR improves.
    def train(self, num_epochs, lr=0.0025, tv_lambda=0, sparsity_lambda=0):
        best_overhead_psnr = 0.0
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        pb = tqdm(total=num_epochs)
        for epoch in range(num_epochs):  ### loop over the dataset multiple times
            count = 0
            train_loss_total = 0
            for measurement_angle, train_image, train_kspace in tqdm(zip(self.training_angles, self.training_data, self.training_kspace)):
                optimizer.zero_grad()  ### zero the parameter gradients
                ### compute the loss based on the model output and the ground truth
                mse, tv, sparsity = self.compute_loss(measurement_angle, gt_img=train_image, kspace_gt=train_kspace, use_overhead=False, tv_lambda=tv_lambda, sparsity_lambda=sparsity_lambda)
                loss = mse + tv + sparsity
                train_loss_total += loss.detach()  # detach so the running total does not hold onto the graph
                count += 1
                loss.backward()  ### backpropagate the loss
                optimizer.step()  ### adjust parameters based on the calculated gradients
            avg_train_loss = train_loss_total / count
            ### Compute and print the PSNR for this epoch when tested overhead
            mseloss, tvloss, sparsityloss = self.compute_loss(self.test_angle, gt_img=self.testing_data, use_overhead=True, tv_lambda=tv_lambda, sparsity_lambda=sparsity_lambda)
            overhead_psnr = -10 * torch.log10(mseloss)
            pb.set_postfix_str(f'Epoch {epoch+1}: train psnr={-10*torch.log10(avg_train_loss):.4f}, overhead psnr={overhead_psnr:.4f}, mseloss={mseloss}, tvloss={tvloss}, sparseloss={sparsityloss}', refresh=False)
            pb.update(1)
            # Save the model if this epoch's overhead PSNR is the best so far
            if overhead_psnr > best_overhead_psnr:
                self.save_model()
                self.save_slices(slices=[self.dim_z // 5, self.dim_z // 5 + 1, 3 * self.dim_z // 4, 3 * self.dim_z // 4 + 1])
                best_overhead_psnr = overhead_psnr
            torch.cuda.empty_cache()
        pb.close()
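

# Illustrative helper (not called by this script): reload the volumes written by save_model()
# above as a single complex array for offline inspection. The directory argument is whatever
# was used as the model's train log dir (model.log_train).
def load_reconstruction(log_train_dir):
    real = np.load(os.path.join(log_train_dir, 'recon_real.npy'))
    imag = np.load(os.path.join(log_train_dir, 'recon_imag.npy'))
    return real + 1j * imag
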
if __name__ == "__main__":
    # Assumes 5 blades that exactly cover the space; the blade width is computed automatically in make_mask
    # organ = 'liver'
    # reso = 256
    organ = 'breast'
    reso = 320
    tv_lambda = 0.01
    sparsity_lambda = 0.1
    lr = 0.02
    omega = 24
    model = PropMRIOffResonanceCorrection(reso, reso, omega, initial_val=0.1, expname=f'{organ}_reso{reso}_{omega}_tv{tv_lambda}_sparsity{sparsity_lambda}_lr{lr}', organ=organ)
    model.train(75, lr=lr, tv_lambda=tv_lambda, sparsity_lambda=sparsity_lambda)
    print(model.log_test)
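
    # A possible follow-up (sketch only, left commented out): render separate fat and water
    # projections of the trained volume; which_layer splits on the sign of the normalized z
    # coordinate, matching the previews saved in compute_loss().
    # fat_proj = model.render_img(model.test_angle, use_overhead=True, which_layer='fat')
    # water_proj = model.render_img(model.test_angle, use_overhead=True, which_layer='water')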