In [None]:
# Make the local microtorchfit package importable (presumably where the
# signal_models, model_maker and utils.* modules imported below live).
# Guarded so re-running this cell does not append a duplicate entry.
import sys

MICROTORCHFIT_DIR = '/Users/paddyslator/python/microtorchfit/'
if MICROTORCHFIT_DIR not in sys.path:
    sys.path.append(MICROTORCHFIT_DIR)


In [None]:
# fit.py
import argparse
import numpy as np
import nibabel as nib
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as utils
from tqdm import tqdm      
import matplotlib.pyplot as plt 

In [None]:
# TODO: write (or find) a function that normalises a generic image.
# TODO: make some simulated data to test the fit on.
# TODO: account for the noise floor.
# TODO: decide whether to fit magnitude or complex-valued data.

In [None]:
# Dataset root directory. It must end with '/' because file paths below are
# built by plain string concatenation. The brain-dki path is kept so the
# notebook can be switched back to that dataset; the HCP subject is active.
BRAIN_DKI_DIR = '/Users/paddyslator/Library/CloudStorage/OneDrive-UniversityCollegeLondon/data/brain-dki/'
HCP_DIR = '/Users/paddyslator/Library/CloudStorage/OneDrive-UniversityCollegeLondon/data/HCP/111312_1/T1w/Diffusion/'
DATADIR = HCP_DIR


In [None]:
# Convert a Siemens-style gradient direction table into the "grad" format used
# below: one row per volume, columns [gx, gy, gz, b].
siemensgradfilename = 'DiffDir_Spiral_Vec15Dir3bmax2800.txt'


def siemens_to_grad(filename, maxb):
    """Convert a Siemens-style gradient table to an (N, 4) [gx, gy, gz, b] array.

    Each row of the input is a gradient vector whose squared norm is the
    fraction of ``maxb`` used for that volume.

    Parameters
    ----------
    filename : str
        Path to a whitespace-delimited text file with three numbers per row.
    maxb : float
        b-value corresponding to a unit-norm gradient vector.

    Returns
    -------
    numpy.ndarray
        (N, 4) array. The first three columns are the gradient directions,
        normalised to unit length for rows with a non-zero b-value (zero rows
        are left as loaded). The last column is the b-value.
    """
    # ndmin=2 keeps a single-row file 2-D so the axis=1 operations below work.
    grad_dirs = np.loadtxt(filename, ndmin=2)

    # Squared vector norm = fraction of maxb played out for each volume.
    grad_dirs_scale = np.linalg.norm(grad_dirs, axis=1) ** 2
    bvals = grad_dirs_scale * maxb

    # Normalise the non-b0 directions to unit length.
    weighted = bvals != 0
    grad_dirs[weighted] = grad_dirs[weighted] / np.sqrt(grad_dirs_scale[weighted])[:, None]

    return np.column_stack((grad_dirs, bvals))
    
    
# Convert the scanner's gradient table and save it beside the data as
# "<basename>_grad.txt" (the brain-dki grad file name used when loading).
# NOTE(review): with DATADIR pointing at the HCP subject this input file
# presumably does not exist — confirm before Restart & Run All.
SIEMENS_MAXB = 2800
grad = siemens_to_grad(DATADIR + siemensgradfilename, SIEMENS_MAXB)

np.savetxt(DATADIR + siemensgradfilename.partition('.')[0] + '_grad.txt', grad)

In [None]:
# Input file names. The commented-out set belongs to the complex-valued
# brain-dki acquisition; the active set is the HCP subject.
# imgfilename = 'imageMatrixCpx.nii.gz'
# maskfilename = 'imageMatrixCpx_mask.nii.gz'
# gradfilename = 'DiffDir_Spiral_Vec15Dir3bmax2800_grad.txt'
imgfilename = 'data.nii.gz'
maskfilename = 'nodif_brain_mask.nii.gz'
gradfilename = 'grad.b'

imgnii = nib.load(DATADIR + imgfilename)
img = imgnii.get_fdata()

masknii = nib.load(DATADIR + maskfilename)
mask = masknii.get_fdata()

# Gradient table: one row per volume; column 3 holds the b-value.
grad = np.loadtxt(DATADIR + gradfilename)

# Round the b-values to whole numbers, then scale them by 1e-3 (presumably
# s/mm^2 -> ms/um^2, the units used for diffusivity later on).
grad[:, 3] = 1e-3 * np.round(grad[:, 3])

# Alternative: shift so the lowest b-value becomes 0.
# grad[:,3] = 1e-3*(grad[:,3] - np.min(grad[:,3]))

# Optional: drop the first n_start volumes (n_start = 0 keeps every volume).
# n_start = 0
# img = img[:,:,:,n_start:]
# grad = grad[n_start:,:]


In [None]:
# Baseline: fit the mean-signal DKI model with DIPY's reference implementation.
from dipy.core.gradients import gradient_table
import dipy.reconst.msdki as msdki

# DIPY gets the b-values multiplied back by 1e3, undoing the 1e-3 scaling
# applied when the gradient table was loaded.
gtab = gradient_table(1e3 * grad[:, 3], bvecs=grad[:, 0:3])

msdki_model = msdki.MeanDiffusionKurtosisModel(gtab)

# Fit every voxel inside the mask.
msdki_fit = msdki_model.fit(img, mask=mask)

# Mean signal diffusivity and mean signal kurtosis maps.
MSD = msdki_fit.msd
MSK = msdki_fit.msk

# Save both maps as one 4-D NIfTI (last axis: [MSD, MSK]), using the input
# image's affine and header as a template.
maps = np.stack((MSD, MSK), axis=-1)
mapsnii = nib.Nifti1Image(maps, affine=imgnii.affine, header=imgnii.header)
# The constructor already updates the header's dim fields from `maps`, but it
# keeps the template's on-disk data type when a header is supplied. Store the
# fractional map values as float32 rather than the source image's type.
mapsnii.set_data_dtype(np.float32)

nib.save(mapsnii, DATADIR + imgfilename[0:-7] + '_DIPY_DK_maps.nii.gz')




In [None]:
# Quick look at the DIPY baseline maps on one axial slice.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))

zslice = 70

im_msk = ax[0].imshow(MSK[:, :, zslice], vmin=0, vmax=2)
fig.colorbar(im_msk, ax=ax[0])
ax[0].set_title('mean signal kurtosis')
ax[0].axis('off')

im_msd = ax[1].imshow(MSD[:, :, zslice], cmap='plasma')
fig.colorbar(im_msd, ax=ax[1])
# NOTE(review): DIPY was given b-values multiplied by 1e3, so MSD is in the
# inverse of those units (mm^2/s if the source table is in s/mm^2) — confirm.
ax[1].set_title('mean signal diffusivity')
ax[1].axis('off');

In [None]:
# Direction-average (spherical mean) the signal over gradient directions;
# presumably yields one volume per b-value shell — see utils.preprocessing.
import utils.preprocessing as preprocessing

da_img, da_grad = preprocessing.direction_average(img, grad)


In [None]:
# Testing shortcut: keep only one axial slice of the brain mask so training
# below is quick. This overwrites `mask` for every later cell.
TEST_SLICE = 70
tmpmask = np.zeros_like(mask)
tmpmask[:, :, TEST_SLICE] = mask[:, :, TEST_SLICE]
mask = tmpmask

In [None]:
# Preprocessing helpers for the network: flatten the masked image into a
# (voxels x volumes) matrix and normalise each voxel by its reference signal.
# TODO: wrap the preprocessing steps below into a single function.


def normalise(X_train, grad, bval_tol=0):
    """Divide every voxel's signal by its mean lowest-b-value signal.

    Parameters
    ----------
    X_train : numpy.ndarray
        (n_voxels, n_volumes) signal matrix.
    grad : array-like
        (n_volumes, >=4) gradient table; column 3 holds the b-values.
    bval_tol : float, optional
        Volumes whose b-value is within ``bval_tol`` of the minimum are all
        used as reference volumes. The default 0 keeps only volumes at exactly
        the minimum b-value (the original behaviour).

    Returns
    -------
    numpy.ndarray
        ``X_train`` divided row-wise by the mean of its reference volumes.
        Voxels whose reference signal is 0 produce inf/nan.
    """
    bvals = np.asarray(grad)[:, 3]
    # Only the b-value is checked. Multi-echo data would also need the
    # echo-time column, i.e. the lowest b AND lowest TE.
    normvol = np.flatnonzero(bvals <= bvals.min() + bval_tol)

    # keepdims gives an (n_voxels, 1) column that broadcasts across all volumes,
    # whether there is one reference volume or several.
    ref_signal = np.mean(X_train[:, normvol], axis=1, keepdims=True)
    return X_train / ref_signal


def img2voxel(img, mask):
    """Flatten a 4-D image into a (voxels-in-mask, volumes) matrix.

    Parameters
    ----------
    img : array-like
        4-D image; the first three axes are spatial, the fourth indexes volumes.
    mask : array-like
        Spatial mask with as many elements as ``img`` has voxels. Only voxels
        whose mask value equals exactly 1 are kept.

    Returns
    -------
    X_train : numpy.ndarray
        (n_masked_voxels, n_volumes) signals, in C (row-major) voxel order.
    maskvox : numpy.ndarray
        The mask flattened to 1-D in the same order, for mapping results back.
    """
    dims = np.shape(img)
    n_voxels = np.prod(dims[0:3])
    flat_img = np.reshape(img, (n_voxels, dims[3]))
    flat_mask = np.reshape(mask, n_voxels)
    return flat_img[flat_mask == 1], flat_mask


# Flatten the direction-averaged image into a (voxels in mask) x (volumes)
# matrix, report the shapes, then normalise each voxel by its lowest-b signal.
X_train, maskvox = img2voxel(da_img, mask)

for arr in (X_train, maskvox):
    print(np.shape(arr))

X_train = normalise(X_train, da_grad)
    

In [None]:
# Simulated-data sanity check (currently disabled: every line is commented out).
# Uncommenting it replaces X_train with signals simulated from a 5-cluster
# (D, K) mixture and defines `tor_params`, the ground-truth parameters that
# the `teston` branch of the inference cell plots against.
# NOTE(review): `from signal_models import msdki` may not match the current
# signal_models module, which is used elsewhere via class names (Ball, Stick) — confirm.
# NOTE(review): no random seed is set here, so the simulated data change on every run.
# #simulated data to try the model on 
# import numpy as np

# #simulate some data from a "cluster model"
# nvox = 1024
# nclus = 5
# p = [0.1, 0.1, 0.2, 0.5]


# p = np.append(p,1-np.sum(p))
# clusters = np.random.choice(range(0,nclus),size=(nvox,),p=p)

# #define the underlying tissue parameters for each cluster
# D = [0.5,1,1.5,2,3]
# K = [1,0.5,0.2,0.1,0.01]
# #K = [0.1,0.05,0.2,0.1,0]

# mu = np.stack((D,K))
# var = np.diag([0.01,0.01])


# params = np.zeros((nvox,2))

# for vox in range(0,nvox):
#     params[vox,:] = np.random.multivariate_normal(mu[:,clusters[vox]],var)
    

# from signal_models import msdki

# tor_params = torch.from_numpy(params)
# tor_grad = torch.from_numpy(da_grad) 
# tor_grad = tor_grad.to(torch.float32)

# S = msdki(tor_grad,tor_params)

# X_train = S.numpy()

In [None]:
nparams = 2


# define the neural network - change to import this from elsewhere!
class Net_test(nn.Module):
    """MLP that maps a normalised signal to MSDKI parameters (D, K).

    The encoder has three hidden ``Linear + ELU`` layers of width ``n_meas``
    (the number of rows in ``grad``), followed by a ``Linear`` layer with
    ``nparams`` outputs. ``forward`` turns the first two outputs into D and K
    and re-synthesises the signal with ``S = exp(-b D + b^2 D^2 K / 6)``.

    Parameters
    ----------
    grad : torch.Tensor
        (n_meas, >=4) gradient table; column 3 holds the b-values.
    nparams : int
        Number of encoder outputs (the forward pass uses the first two).
    """

    def __init__(self, grad, nparams):  # PASS MODEL STRING AS AN ARGUMENT IN HERE!
        super(Net_test, self).__init__()

        # Registered as a buffer (not a plain attribute) so net.to(device)
        # moves it with the weights; persistent=False keeps it out of
        # state_dict, so saved/loaded state dicts are unchanged.
        self.register_buffer('grad', grad, persistent=False)

        n_meas = grad.size(0)
        self.fc_layers = nn.ModuleList()
        for _ in range(3):  # 3 fully connected hidden layers
            self.fc_layers.extend([nn.Linear(n_meas, n_meas), nn.ELU()])
        # encoder re-uses the fc_layers modules, so the hidden weights appear in
        # state_dict under both the "fc_layers." and "encoder." prefixes.
        self.encoder = nn.Sequential(*self.fc_layers, nn.Linear(n_meas, nparams))

        # self.dropout = nn.Dropout(0.5)

    def forward(self, X):
        """Return (predicted signal, D, K); D and K have shape (batch, 1)."""
        # X = self.dropout(X)
        params = torch.abs(self.encoder(X))  # columns: D, K
        # Keep the parameters inside fixed ranges.
        D = torch.clamp(params[:, 0].unsqueeze(1), min=0.001, max=3)
        K = torch.clamp(params[:, 1].unsqueeze(1), min=0.001, max=2)

        bvals = self.grad[:, 3]

        # (batch, 1) parameters broadcast against the (n_meas,) b-values.
        X = torch.exp(-bvals * D + 1 / 6 * bvals ** 2 * D ** 2 * K)

        return X, D, K
    
    

# Build the hand-written test network on the direction-averaged gradient table.
# NOTE: this rebinds `grad` from the full numpy table to a float32 torch tensor
# of the direction-averaged table; later cells (Net construction, the final
# signal plot) use this version.
grad = torch.tensor(da_grad, dtype=torch.float32)
# Seed before construction so the initial weights are reproducible.
torch.manual_seed(42)
net_test = Net_test(grad, nparams)



In [None]:
# Build the network from composable signal-model compartments.
import importlib

import torch.nn as nn
from model_maker import ModelMaker
from utils.net_maker import Net

# Names of compartment classes in the signal_models module, in model order.
# comps = ("MSDKI",)
comps = ("Ball", "Stick")

# Import dynamically and instantiate each compartment class (no arguments).
signal_models_module = importlib.import_module("signal_models")
comps_classes = tuple(getattr(signal_models_module, comp)() for comp in comps)

modelfunc = ModelMaker(comps_classes)

# Seed before construction so the (presumably random) initial weights are
# reproducible; the training cell's seed is set only after the net exists.
torch.manual_seed(42)
net = Net(grad, modelfunc, dim_hidden=grad.shape[0], num_layers=3,
          dropout_frac=0, activation=nn.ELU())




In [None]:
# Inspect the composed model's n_params (presumably the number of parameters the network must predict).
modelfunc.n_params

In [None]:
# Scratch check for normalise(): a per-voxel mean tiled across the volumes must
# come out with the same shape as X_train, (n_voxels, n_volumes).
thing = np.mean(X_train, axis=1)
np.shape(np.tile(thing, (np.shape(X_train)[1], 1)).T)

In [None]:
# Training set-up. teston=True trains the hand-written Net_test;
# teston=False trains the network built from the signal_models compartments.
teston = False

# Optional custom weight initialisation (currently unused):
# def weights_init(m):
#     if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
#         nn.init.xavier_uniform_(m.weight.data)
#         nn.init.zeros_(m.bias.data)
# net.apply(weights_init)

# Reconstruction loss between measured and re-synthesised signals.
criterion = nn.MSELoss()

# Adam over whichever network is active.
# (Alternative: optim.SGD(net.parameters(), lr = 0.01))
active_net = net_test if teston else net
optimizer = optim.Adam(active_net.parameters(), lr=0.01)

# Mini-batches of voxels. drop_last=True discards the final partial batch, so
# fewer than batch_size voxels would leave the loader empty.
batch_size = 128
num_batches = len(X_train) // batch_size

# X_train = X_train[:,1:] # exclude the b=0 value as signals are normalized

# `utils` here is torch.utils.data (imported at the top of the notebook).
trainloader = utils.DataLoader(
    torch.from_numpy(X_train.astype(np.float32)),
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    drop_last=True,
)

# Optional learning-rate scheduler:
# from torch.optim import lr_scheduler
# scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)



In [None]:

# Train with early stopping on the summed per-epoch loss.
best = 1e16
num_bad_epochs = 0
patience = 10

torch.manual_seed(42)

active_net = net_test if teston else net


def _snapshot(model):
    """Return a detached copy of model's weights.

    state_dict() returns references to the live parameter tensors, so without
    cloning, later optimizer steps would silently overwrite the saved "best"
    weights.
    """
    return {k: v.clone() for k, v in model.state_dict().items()}


# Start from the initial weights so the restore below always has something to
# load, even if no epoch improves on `best` (e.g. a nan loss).
final_model = _snapshot(active_net)

# Train
for epoch in range(100):
    print("-----------------------------------------------------------------")
    print("Epoch: {}; Bad epochs: {}".format(epoch, num_bad_epochs))
    active_net.train()

    running_loss = 0.

    for i, X_batch in enumerate(tqdm(trainloader), 0):
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize; both networks return the predicted
        # signal first, followed by their parameter estimates
        if teston:
            X_pred, D_pred, K_pred = active_net(X_batch)
        else:
            X_pred, params_pred = active_net(X_batch)

        loss = criterion(X_pred, X_batch)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    print("Loss: {}".format(running_loss))
    # early stopping
    if running_loss < best:
        print("############### Saving good model ###############################")
        final_model = _snapshot(active_net)
        best = running_loss
        num_bad_epochs = 0
    else:
        num_bad_epochs = num_bad_epochs + 1
        if num_bad_epochs == patience:
            print("Done, best loss: {}".format(best))
            break
print("Done")
# Restore best model
active_net.load_state_dict(final_model)

In [None]:
# Inference: run the trained network on every masked voxel.
active_net = net_test if teston else net
active_net.eval()

with torch.no_grad():
    X_all = torch.from_numpy(X_train.astype(np.float32))
    if teston:
        X, D, K = active_net(X_all)

        D = D.numpy()
        K = K.numpy()

        # The ground-truth comparison only exists when the (disabled)
        # simulated-data cell has been run, since that cell defines tor_params.
        if 'tor_params' in globals():
            plt.plot(tor_params[:, 0], D, 'o')
    else:
        X, params = active_net(X_all)
        params = params.numpy()
    







In [None]:
# Largest value in the normalised training matrix. After normalise() the
# reference volumes average to 1 per voxel, so very large values presumably
# indicate voxels with a small or noisy reference signal.
np.max(X_train)

In [None]:
# Optional (disabled): rotate the parameter maps by -90 degrees for display.
# #rotate images
# from scipy import ndimage

# D_map = ndimage.rotate(D_map,-90,reshape=True)
# K_map = ndimage.rotate(K_map,-90,reshape=True)



In [None]:
# Placeholder array holding two parameter maps on the image grid.
param_map = np.zeros(np.shape(mask) + (2,))

print(param_map.shape)

In [None]:
# Scratch: mark the in-mask voxels of a flattened array with a sentinel value.
tmpparams = np.where(maskvox == 1, 2020, 0).astype(maskvox.dtype)

In [None]:
# Display the network's parameter maps on one axial slice.
def voxel2img(values, maskvox, shape):
    """Scatter per-voxel values back onto the image grid (0 outside the mask).

    Inverse of img2voxel for one parameter: ``values`` holds one entry per
    voxel where ``maskvox == 1``, in the order img2voxel extracted them.
    """
    vox = np.zeros(np.shape(maskvox))
    vox[maskvox == 1] = np.squeeze(values)
    return np.reshape(vox, shape)


if teston:
    zslice = 28
    D_vals, K_vals = D, K
else:
    zslice = 70
    # NOTE(review): assumes params[:, 0] and params[:, 1] are the quantities to
    # show as "diffusivity" and "kurtosis" — confirm against modelfunc's
    # parameter order (comps is currently Ball + Stick).
    D_vals, K_vals = params[:, 0], params[:, 1]

D_map = voxel2img(D_vals, maskvox, np.shape(mask))
K_map = voxel2img(K_vals, maskvox, np.shape(mask))

fig, ax = plt.subplots(1, 2, figsize=(10, 5))

plt0 = ax[0].imshow(D_map[:, :, zslice], vmin=0, vmax=3)
plt.colorbar(plt0, ax=ax[0])
ax[0].xaxis.set_ticklabels([])
# raw string: '\m' is not a valid escape sequence
ax[0].set_title(r'diffusivity ($\mu$m$^2$/ms)')
ax[0].axis('off')

plt0 = ax[1].imshow(K_map[:, :, zslice], cmap='plasma', vmin=0, vmax=2)
plt.colorbar(plt0, ax=ax[1])
ax[1].xaxis.set_ticklabels([])
ax[1].set_title('kurtosis')
ax[1].axis('off');

    

In [None]:
# Save the inferred maps as one 4-D NIfTI (last axis: [D_map, K_map]), using
# the input image's affine and header as a template.
maps = np.stack((D_map, K_map), axis=-1)

mapsnii = nib.Nifti1Image(maps, affine=imgnii.affine, header=imgnii.header)
# The constructor already updates the header's dim fields from `maps`, but it
# keeps the template's on-disk data type when a header is supplied. Store the
# fractional map values as float32 rather than the source image's type.
mapsnii.set_data_dtype(np.float32)

nib.save(mapsnii, DATADIR + imgfilename[0:-7] + '_DK_maps.nii.gz')

In [None]:
# Predicted signal for masked voxel 1000 from the inference cell (needs more than 1000 masked voxels).
X[1000,:]

In [None]:
# Loss on the last training mini-batch. X_pred/X_batch are left over from the final
# training iteration; they are not recomputed with the restored best weights.
criterion(X_pred, X_batch)

In [None]:
# Measured vs predicted signal against b-value for one voxel (index 1) of the
# last training batch; `grad` is the direction-averaged table at this point.
fig, ax = plt.subplots()
ax.plot(grad[:, 3], X_batch[1, :].detach().numpy(), 'o', label='measured')
ax.plot(grad[:, 3], X_pred[1, :].detach().numpy(), 'x', label='predicted')
ax.set_xlabel('b-value')
ax.set_ylabel('normalised signal')
ax.legend();