Deep Learning How to Fit a ball-stick Model to HCP diffusion MRI data

This is an adaptation of the original notebook describing the IVIM fitting approach proposed in "Deep Learning How to Fit an Intravoxel Incoherent Motion Model to Diffusion-Weighted MRI" by Barbieri et al., 2019. A preprint of the paper can be found at: https://arxiv.org/abs/1903.00095

Note that I wrote this code quickly and without much care, so there are probably some bugs!

Authors: Paddy Slator, Jason Lim, UCL.
p.slator@ucl.ac.uk


In [None]:
# import libraries
import numpy as np
import matplotlib.pyplot as plt
# import nibabel as nib
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as utils
from tqdm import tqdm

In [None]:
def cart2mu(xyz):
    """Convert Cartesian vectors to spherical angles (theta, phi).

    Parameters
    ----------
    xyz : array_like, shape (..., 3)
        Cartesian vectors with the x, y, z components on the last axis.
        A single vector of shape (3,) is accepted as well.

    Returns
    -------
    mu : numpy.ndarray, shape (..., 2)
        ``mu[..., 0]`` is the polar angle ``theta = arccos(z / |xyz|)`` and
        ``mu[..., 1]`` is the azimuth ``phi = arctan2(y, x)``.  Zero-length
        vectors map to ``(0, 0)``.
    """
    xyz = np.asarray(xyz, dtype=float)
    r = np.linalg.norm(xyz, axis=-1)

    # Divide by 1 instead of 0 for zero-length vectors so no NaN or
    # RuntimeWarning is produced; those entries are forced to 0 below.
    nonzero = r > 0
    safe_r = np.where(nonzero, r, 1.0)

    # Rounding can push z / r marginally outside [-1, 1], where arccos is NaN.
    cos_theta = np.clip(xyz[..., 2] / safe_r, -1.0, 1.0)

    mu = np.zeros(xyz.shape[:-1] + (2,))
    mu[..., 0] = np.where(nonzero, np.arccos(cos_theta), 0.0)
    mu[..., 1] = np.where(nonzero, np.arctan2(xyz[..., 1], xyz[..., 0]), 0.0)
    return mu


In [None]:
# IPython shell escape: list the subject's diffusion directory to check which files are present.
!ls ~/OneDrive\ -\ University\ College\ London/data/HCP/111312_1/T1w/Diffusion/


In [None]:
# Directory holding the subject's diffusion files (bvals, bvecs, data.nii.gz, nodif_brain_mask.nii.gz).
# Hard-coded to the author's machine -- change it to your local copy.  Keep the trailing slash:
# file paths below are built by plain string concatenation.
datadir = "/Users/paddyslator/OneDrive - University College London/data/HCP/111312_1/T1w/Diffusion/"




In [None]:
# Display the full path of the b-values file as a quick sanity check.
datadir + "bvals"

In [None]:
# Scale the b-values by 1e-3 as they are read.
# NOTE(review): presumably this converts s/mm^2 to ms/um^2 so that fitted
# diffusivities come out in um^2/ms (the unit used in the plot titles) -- confirm.
bvals = 1e-03 * np.loadtxt(datadir + "bvals")

# Transpose the gradient table so that each row holds one gradient direction,
# matching the one-row-per-volume layout of bvals.
bvecs = np.loadtxt(datadir + "bvecs").T



In [None]:
# Check the shapes line up before stacking: one row per volume in both arrays.
bvals_column = bvals[:, None]
print(bvals_column.shape)
print(bvecs.shape)

# Acquisition table: the gradient-direction columns followed by the b-value column.
grad = np.hstack((bvecs, bvals_column))

In [None]:
# Stand-alone torch implementations of the signal model, defined separately
# from the network class.  (Open question from the author: should the models
# live on their own like this?)

# Only `ball_stick` is declared public; the helper functions are not listed.
__all__ = [
    'ball_stick'
]


def ball_stick(grad, params):
    """Evaluate the two-compartment ball-and-stick signal model.

    The signal is ``f * stick + (1 - f) * ball``.

    Parameters
    ----------
    grad : torch.Tensor
        Acquisition table with one row per measurement.  It is passed
        unchanged to ``stick`` and ``ball``, which read the gradient
        direction from columns 0-2 and the b-value from column 3.
    params : torch.Tensor, shape (B, 5)
        One row per voxel; the columns are, in order: stick fraction ``f``,
        parallel diffusivity ``Dpar``, isotropic diffusivity ``Diso``, and the
        stick orientation angles ``theta`` and ``phi``.

    Returns
    -------
    torch.Tensor
        Predicted signal, one row per row of ``params``.
    """
    # Keep each parameter as a (B, 1) column so it broadcasts across the
    # measurements.
    f = params[:, 0].unsqueeze(1)
    Dpar = params[:, 1].unsqueeze(1)
    Diso = params[:, 2].unsqueeze(1)
    theta = params[:, 3].unsqueeze(1)
    phi = params[:, 4].unsqueeze(1)

    # The compartment functions slice `grad` themselves, so no gradient or
    # b-value locals are needed here.
    return f * stick(grad, Dpar, theta, phi) + (1 - f) * ball(grad, Diso)


def ball(grad, Diso):
    """Isotropic ("ball") compartment: mono-exponential decay ``exp(-b * Diso)``.

    ``grad`` carries the b-value of each measurement in column 3.  ``Diso`` is
    broadcast against that vector of b-values, so a (B, 1) column of
    diffusivities yields a (B, N) signal for N measurements.
    """
    b = grad[:, 3]
    attenuation = b * Diso
    return torch.exp(-attenuation)


def stick(grad, Dpar, theta, phi):
    """Anisotropic ("stick") compartment: ``exp(-b * Dpar * (g . n)^2)``.

    Parameters
    ----------
    grad : torch.Tensor, shape (N, 4)
        One row per measurement: gradient direction in columns 0-2, b-value
        in column 3.
    Dpar : torch.Tensor, shape (B, 1)
        Diffusivity along the stick, one entry per voxel.
    theta, phi : torch.Tensor, shape (B, 1)
        Spherical angles of the stick orientation ``n``, one entry per voxel.

    Returns
    -------
    torch.Tensor, shape (B, N)
        Stick signal, laid out like the output of ``ball``.
    """
    # All three direction components are needed (0:2 would drop z).
    g = grad[:, 0:3]
    bvals = grad[:, 3]

    # Unit orientation vectors, one column per voxel: (3, B).  Match the
    # gradient table's dtype/device so the matrix product is well defined.
    n = sphere2cart(theta, phi).to(device=g.device, dtype=g.dtype)

    # (N, 3) @ (3, B) -> (N, B); transpose so rows are voxels, like bvals * Dpar.
    projection = torch.mm(g, n).t()

    return torch.exp(-bvals * Dpar * projection ** 2)

def sphere2cart(theta, phi):
    """Convert spherical angles to Cartesian unit vectors.

    Parameters
    ----------
    theta : torch.Tensor, shape (B, 1) or (B,)
        Polar angle (measured from the z axis).
    phi : torch.Tensor, shape (B, 1) or (B,)
        Azimuthal angle (measured in the x-y plane from the x axis).

    Returns
    -------
    torch.Tensor, shape (3, B)
        Column ``k`` is ``(sin(theta_k) cos(phi_k), sin(theta_k) sin(phi_k),
        cos(theta_k))``.  The result has the dtype and device of the inputs.
    """
    # Flatten to 1-D so a single-voxel batch keeps its length-1 axis.
    theta = theta.reshape(-1)
    phi = phi.reshape(-1)

    sintheta = torch.sin(theta)
    return torch.stack(
        (sintheta * torch.cos(phi), sintheta * torch.sin(phi), torch.cos(theta)),
        dim=0,
    )

In [None]:
#paramstor = torch.tensor([[0.5, 1, 2, 0.2, 0.3],[0.2, 2, 2, 0, 2]])
#gradtor = torch.tensor(grad)

#ball_stick(gradtor, paramstor)

In [None]:
#jump straight in and define the neural network!

class Net(nn.Module):
    """Fully connected encoder that fits a ball-and-stick model voxel-wise.

    The encoder maps a voxel's measured signal (one value per measurement) to
    ``nparams`` model parameters; ``forward`` then pushes those parameters
    through the ball-and-stick signal equation, so the network can be trained
    by comparing its output with its own input.

    Parameters
    ----------
    gradient_directions_no0 : torch.Tensor, shape (N, 3)
        Gradient direction of each measurement.
    b_values_no0 : torch.Tensor, shape (N,)
        b-value of each measurement; its length sets the layer width.
    nparams : int
        Number of encoder outputs.  ``forward`` reads outputs 0-4, so this
        must be at least 5.
    """

    def __init__(self, gradient_directions_no0, b_values_no0, nparams):
        super(Net, self).__init__()
        # Acquisition scheme used by the signal equation in forward().
        self.gradient_directions_no0 = gradient_directions_no0
        self.b_values_no0 = b_values_no0

        n_meas = len(b_values_no0)
        self.fc_layers = nn.ModuleList()
        for _ in range(3):  # 3 fully connected hidden layers
            self.fc_layers.extend([nn.Linear(n_meas, n_meas), nn.ReLU()])
        self.encoder = nn.Sequential(*self.fc_layers, nn.Linear(n_meas, nparams))

    def forward(self, X):
        """Estimate the model parameters for ``X`` and re-synthesise the signal.

        Parameters
        ----------
        X : torch.Tensor, shape (B, N)
            Signals for B voxels.

        Returns
        -------
        tuple
            ``(X_pred, D_par, D_iso, mu_cart, Fp)``: the predicted signal
            (B, N), the stick and ball diffusivities (B, 1) each, the stick
            orientation as unit vectors (3, B), and the weight of the ball
            compartment (B, 1).
        """
        # abs() keeps every estimated parameter non-negative.
        params = torch.abs(self.encoder(X))

        D_par = torch.clamp(params[:, 0].unsqueeze(1), min=0.001, max=3)
        D_iso = torch.clamp(params[:, 1].unsqueeze(1), min=0.001, max=3)
        theta = params[:, 2]
        phi = params[:, 3]
        # NOTE(review): Fp is only made non-negative, not capped at 1, so
        # (1 - Fp) can go negative.  The clamp was disabled by the author;
        # confirm whether that is intended.
        Fp = params[:, 4].unsqueeze(1)

        # Stick orientation as unit vectors, one column per voxel.  Stacking
        # (rather than filling a freshly allocated buffer) keeps the result on
        # the same device and in the same dtype as the network output.
        sintheta = torch.sin(theta)
        mu_cart = torch.stack(
            (sintheta * torch.cos(phi), sintheta * torch.sin(phi), torch.cos(theta)),
            dim=0,
        )

        # (N, 3) x (3, B) -> (B, N): projection of each gradient direction on
        # each voxel's stick orientation.
        projection = torch.einsum("ij,jk->ki", self.gradient_directions_no0, mu_cart)

        ball_signal = torch.exp(-self.b_values_no0 * D_iso)
        stick_signal = torch.exp(-self.b_values_no0 * D_par * projection ** 2)
        X = Fp * ball_signal + (1 - Fp) * stick_signal
        return X, D_par, D_iso, mu_cart, Fp
    
    

In [None]:
# Build the network from the acquisition scheme (single-precision tensors).
nparams = 5
b_values_no0 = torch.tensor(bvals, dtype=torch.float32)
gradient_directions_no0 = torch.tensor(bvecs, dtype=torch.float32)
net = Net(gradient_directions_no0, b_values_no0, nparams)

# Mean-squared error between predicted and measured signal, minimised with Adam.
criterion = nn.MSELoss()
optimizer = optim.Adam(net.parameters(), lr=1e-3)

In [None]:
# nibabel is imported here, in the cell that first needs it.
import nibabel as nib
#load in some data and try fitting!
# Open the diffusion image and the brain mask stored alongside the bvals/bvecs files.
imgnii = nib.load(datadir + "data.nii.gz")
masknii = nib.load(datadir + "nodif_brain_mask.nii.gz")

# Pull the voxel data out of the image objects as arrays.
# NOTE(review): this reads the whole diffusion image into memory at once -- may be large.
img = imgnii.get_fdata()
mask = masknii.get_fdata()

In [None]:
# Flatten the three spatial axes so each row is one voxel and each column one volume.
spatial_shape = img.shape[:3]
nvoxtotal = np.prod(spatial_shape)
nvol = img.shape[3]

imgvox = img.reshape(nvoxtotal, nvol)


In [None]:
# Restrict the fit to a single slice (index 70 along the third axis) for now.
fit_slice = 70
masktmp = np.zeros(mask.shape)
masktmp[:, :, fit_slice] = mask[:, :, fit_slice]
mask = masktmp

# Flatten the mask the same way as the image so the two can be indexed together.
maskvox = mask.reshape(nvoxtotal)

In [None]:
# Keep only the voxels whose (flattened) mask value is 1; rows are voxels, columns are volumes.
imgvoxtofit = imgvox[maskvox==1]

In [None]:
# Normalise every voxel by its mean signal over the lowest-b volumes.

# Indices of the volumes acquired at the smallest b-value.
normvol = np.where(bvals == np.min(bvals))

# Indexing with the tuple from np.where gives (nvox, 1, k); averaging the last
# axis leaves a (nvox, 1) column that broadcasts across all volumes.
# NOTE(review): a voxel whose reference mean is 0 produces inf/NaN here.
reference_signal = np.mean(imgvoxtofit[:, normvol], axis=2)
imgvoxtofitnorm = imgvoxtofit / reference_signal



In [None]:
# Batch the normalised voxel signals for training.
batch_size = 128
num_batches = len(imgvoxtofitnorm) // batch_size

# Every volume is kept as a network input; the data are converted to float32 first.
train_signals = torch.from_numpy(imgvoxtofitnorm.astype(np.float32))
trainloader = utils.DataLoader(
    train_signals,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    drop_last=True,  # discard the final, incomplete batch
)



In [None]:
import copy

# Early-stopping state: lowest summed epoch loss seen so far, and how many
# consecutive epochs have failed to improve on it.
best = 1e16
num_bad_epochs = 0
patience = 20

# Train: the network reconstructs its own input, so the batch is also the target.
for epoch in range(1000):
    print("-----------------------------------------------------------------")
    print("Epoch: {}; Bad epochs: {}".format(epoch, num_bad_epochs))
    net.train()
    running_loss = 0.

    for X_batch in tqdm(trainloader):
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize (only the predicted signal enters the loss)
        X_pred, *_ = net(X_batch)
        loss = criterion(X_pred, X_batch)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # running_loss is the sum of the per-batch losses over the epoch
    print("Loss: {}".format(running_loss))
    # early stopping
    if running_loss < best:
        print("############### Saving good model ###############################")
        # state_dict() returns references to the live parameter tensors, so it
        # must be deep-copied; otherwise later optimizer steps would silently
        # overwrite the saved "best" weights.
        final_model = copy.deepcopy(net.state_dict())
        best = running_loss
        num_bad_epochs = 0
    else:
        num_bad_epochs = num_bad_epochs + 1
        if num_bad_epochs == patience:
            print("Done, best loss: {}".format(best))
            break
print("Done")
# Restore best model
net.load_state_dict(final_model)

In [None]:
# Run the trained network over every masked voxel and collect the estimates as numpy arrays.

net.eval()
voxel_signals = torch.from_numpy(imgvoxtofitnorm.astype(np.float32))
with torch.no_grad():
    X_real_pred, D_par, D_iso, mu_cart, Fp = (
        output.numpy() for output in net(voxel_signals)
    )

# The network returns orientations as (3, nvox); put xyz on the last axis and
# convert to spherical angles.
mu_cart_transposed = mu_cart.T
mu_vals = cart2mu(mu_cart_transposed)
theta, phi = mu_vals[:, 0], mu_vals[:, 1]

In [None]:
from scipy import ndimage


def _to_rotated_map(values, ncomp=None):
    """Scatter per-voxel estimates back into the image grid.

    Returns the flattened array (zeros outside the mask) and that array
    reshaped to the mask's shape and rotated by 90 degrees for display.
    ``ncomp`` adds a trailing axis of that length for vector-valued estimates.
    """
    if ncomp is None:
        flat_shape = np.shape(maskvox)
        grid_shape = np.shape(mask)
    else:
        flat_shape = (np.shape(maskvox)[0], ncomp)
        grid_shape = np.append(np.shape(mask), ncomp)

    flat = np.zeros(flat_shape)
    flat[maskvox == 1] = values
    rotated = ndimage.rotate(np.reshape(flat, grid_shape), 90, reshape=False)
    return flat, rotated


# Scalar parameter maps.
D_par_vox, D_par_map = _to_rotated_map(np.squeeze(D_par))
D_iso_vox, D_iso_map = _to_rotated_map(np.squeeze(D_iso))
theta_vox, theta_map = _to_rotated_map(np.squeeze(theta))
phi_vox, phi_map = _to_rotated_map(np.squeeze(phi))
Fp_vox, Fp_map = _to_rotated_map(np.squeeze(Fp))

# Orientation map: one xyz vector per voxel.
mu_cart_vox, mu_cart_map = _to_rotated_map(np.transpose(mu_cart), ncomp=3)



In [None]:
# One panel per fitted parameter, all taken from the same slice.
fig, ax = plt.subplots(5, 1, figsize=(5,20))

zslice = 70

# (image, title) for each panel.  Titles containing LaTeX are raw strings so
# that "\m" is not treated as an (invalid) escape sequence.
panels = [
    (D_par_map[:,:,zslice], r'stick parallel diffusivity ($\mu$m$^2$/ms)'),
    (D_iso_map[:,:,zslice], r'ball isotropic diffusivity ($\mu$m$^2$/ms)'),
    (theta_map[:,:,zslice], 'theta'),
    (phi_map[:,:,zslice], 'phi'),
    # Fp weights the ball compartment, so the stick fraction is its complement.
    (1-Fp_map[:,:,zslice], 'stick volume fraction'),
]

for panel_ax, (image, title) in zip(ax, panels):
    plt0 = panel_ax.imshow(image)
    plt.colorbar(plt0, ax=panel_ax)
    panel_ax.set_title(title)
    # Hides ticks, tick labels and the frame in one go.
    panel_ax.axis('off')


In [None]:
# Display the shape of the orientation map: the mask's dimensions plus a trailing axis of length 3.
np.shape(mu_cart_map)

In [None]:
# Headless-arrow plot of the in-plane (x, y) components of the fitted stick orientation.
fig, ax = plt.subplots(1, 1, figsize=(50, 50))

# One grid point per voxel of the slice.
nrows, ncols = np.shape(mu_cart_map)[:2]
x, y = np.meshgrid(np.linspace(0, ncols, ncols), np.linspace(0, nrows, nrows))

u, v = (mu_cart_map[:, :, zslice, component] for component in (0, 1))

plt.quiver(x, y, u, v, headlength=0, headaxislength=0)
plt.show()
