### Build a diffusion MRI (dMRI) image from autoencoder-reconstructed volumes

In [1]:
# Work from the project root (presumably needed so the `deepmri` and `DiffusionMRI`
# packages imported below resolve). Guarded so that re-running this cell does not
# keep moving up one directory each time, which a bare `cd ..` would do.
import os

if not (os.path.isdir('deepmri') and os.path.isdir('DiffusionMRI')):
    os.chdir(os.pardir)
print(os.getcwd())

/home/agajan/DeepMRI


In [2]:
import torch
import torch.nn as nn
from deepmri.Datasets import Slice3dDataset
import deepmri.utils as utils
from DiffusionMRI.ConvEncoder import ConvEncoder
from DiffusionMRI.ConvDecoder import ConvDecoder
import torch.optim as optim
import pickle
import matplotlib.pyplot as plt
import nibabel as nib
import time
import os
import numpy as np
%matplotlib inline

# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if _use_cuda else torch.device("cpu")
print("Device: ", device)

Device:  cuda:0


In [3]:
# Build the single-channel conv autoencoder and move it to the selected device
# (nn.Module.to returns the module itself, so construction and placement chain).
encoder = ConvEncoder(input_channels=1).to(device)
decoder = ConvDecoder(out_channels=1).to(device)

criterion = nn.MSELoss()

# Trained weights (step 1, epoch 8).
encoder_path = '/home/agajan/DeepMRI/DiffusionMRI/models/step1_conv_encoder_epoch_8'
decoder_path = '/home/agajan/DeepMRI/DiffusionMRI/models/step1_conv_decoder_epoch_8'
# map_location remaps stored tensors onto `device`, so checkpoints saved from GPU
# tensors still load when `device` fell back to the CPU.
encoder.load_state_dict(torch.load(encoder_path, map_location=device))
decoder.load_state_dict(torch.load(decoder_path, map_location=device))

# Inference only: eval() makes the BatchNorm layers use their running statistics.
encoder.eval()
decoder.eval()

ConvDecoder(
  (decode): Sequential(
    (0): ConvTranspose3d(64, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), output_padding=(0, 1, 0))
    (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): ConvTranspose3d(32, 16, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
    (4): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): ConvTranspose3d(16, 1, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), output_padding=(0, 1, 0))
    (7): BatchNorm3d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

In [4]:
# Subject whose diffusion series is reconstructed.
subj_id = '198451'

# Input layout: <data_path>/<subj_id>/Diffusion/data.nii.gz
data_path = '/media/schultz/345de007-c698-4c33-93c1-3964b99c5df6/regina/'
subj_dir = os.path.join(data_path, subj_id)
dmri_path = os.path.join(subj_dir, 'Diffusion/data.nii.gz')

# Directory the reconstructed NIfTI image is written to.
save_path = '/media/schultz/345de007-c698-4c33-93c1-3964b99c5df6/agajan/experiment_DiffusionMRI/reconstructions/'

In [None]:
dmri = nib.load(dmri_path)
# get_fdata() defaults to float64, which doubles the memory of the full 4D series;
# float32 halves it (a second array of the same shape is allocated for the reconstruction).
dmri_data = dmri.get_fdata(dtype=np.float32)
dmri.shape

(145, 174, 145, 288)


In [None]:
def reconstruct_volume(volume, encoder, decoder, device):
    """Run one 3D volume through the autoencoder and return it in the volume's own intensity units.

    The volume is standardized with its own mean/std before encoding, and the
    decoder output is mapped back with the same statistics.

    Args:
        volume: 3D numpy array, one volume of the 4D dMRI series.
        encoder, decoder: the trained ConvEncoder / ConvDecoder (expected to be in eval mode).
        device: torch device the models live on.

    Returns:
        numpy array with the de-normalized decoder output, singleton dims squeezed.
    """
    mu = volume.mean()
    std = volume.std()
    if std == 0:  # constant volume: avoid 0/0 -> NaN; the standardized input is then all zeros
        std = 1.0
    standardized = (volume - mu) / std
    # from_numpy avoids the extra copy torch.tensor() makes; add batch and channel dims.
    vol = torch.from_numpy(standardized).float().unsqueeze(0).unsqueeze(0)
    with torch.no_grad():
        out = decoder(encoder(vol.to(device))).cpu().squeeze().numpy()
    return out * std + mu  # get back initial numbers


n_volumes = dmri.shape[3]
# float32 storage matches the precision the decoder computes in and halves memory
# compared with np.zeros' float64 default.
new_data = np.zeros(dmri.shape, dtype=np.float32)
for i in range(n_volumes):
    new_data[:, :, :, i] = reconstruct_volume(dmri_data[:, :, :, i], encoder, decoder, device)
    if (i + 1) % 50 == 0:  # progress every 50 volumes
        print("{}/{}".format(i + 1, n_volumes))
print("{}/{}".format(n_volumes, n_volumes))

print("Creating new dMRI image...")
# Reuse the original affine and header so the reconstruction is spatially aligned with the input.
new_dmri_img = nib.Nifti1Image(new_data, dmri.affine, dmri.header)
# Make sure the output directory exists before writing to it.
os.makedirs(save_path, exist_ok=True)
dst = os.path.join(save_path, 'conv_rec_' + subj_id + '.nii.gz')
print("Saving new dMRI image...")
nib.save(new_dmri_img, dst)
print("Reconstructed dMRI image was saved")

50/288
100/288
150/288
200/288
250/288
288/288
Creating new dMRI image...
Saving new dMRI image...


In [None]:
# Voxel indices along axes 0, 1, 2 for the slices, and the volume index `t` along the 4th axis.
x_coord = 100
y_coord = 87
z_coord = 73
t = 100

volume_orig = dmri_data[:, :, :, t]
volume_rec = new_data[:, :, :, t]


def orthogonal_slices(volume, x, y, z):
    """Return the three axis-aligned 2D slices of a 3D volume passing through voxel (x, y, z)."""
    return [volume[x, :, :], volume[:, y, :], volume[:, :, z]]


utils.show_slices(
    orthogonal_slices(volume_orig, x_coord, y_coord, z_coord),
    figsize=(15, 8), suptitle="Original Image",
)

# The decoder output can go below zero; show it both clipped at 0 and as-is.
utils.show_slices(
    [s.clip(min=0) for s in orthogonal_slices(volume_rec, x_coord, y_coord, z_coord)],
    figsize=(15, 8), suptitle="Reconstructed Image with min=0",
)

utils.show_slices(
    orthogonal_slices(volume_rec, x_coord, y_coord, z_coord),
    figsize=(15, 8), suptitle="Reconstructed Image (contains negative vals)",
)

In [None]:
# MSE between the original and the (unclipped) reconstructed volume, in raw intensity units.
# Arguments follow nn.MSELoss's (input, target) order; MSE is symmetric, so the value is the same.
# These tensors do not require grad, so no autograd context is needed.
orig_tensor = torch.as_tensor(volume_orig, dtype=torch.float32)
rec_tensor = torch.as_tensor(volume_rec, dtype=torch.float32)
print(criterion(rec_tensor, orig_tensor))

In [None]:
# Same metric after clamping negative reconstructed intensities to 0, i.e. the image shown in
# the "Reconstructed Image with min=0" figure (the unclipped MSE is printed by the cell above).
print(criterion(rec_tensor.clamp(min=0), orig_tensor))

In [None]:
# Smallest value in the original vs. reconstructed volume (the reconstruction may dip below 0).
orig_tensor.min(), rec_tensor.min()