# Interactive tool for estimation of parameters in quantitative MRI

# Visualization of the results : $\color{red}{\text{ONLINE FRAMEWORK}}$ 

## Visiting Student : Quentin Duchemin
## Professors : Carlos Fernandez Granda & Jakob Assländer



### 1) Presentation of the project

We try to estimate the biological parameters in quantitative MRI using neural networks. This notebook is an interactive tool that lets you visualize the results of your training runs. Use it only for networks trained on fingerprints that are computed $\color{red}{\text{ONLINE}}$.

### 2) How to use this interface ?

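This tool lets you visualize the results of training. Select a saved network and a visualization type in the dropdown menus below, then click the button to draw the corresponding plot.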

In [4]:
! pip install ipywidgets
! jupyter nbextension enable --py widgetsnbextension
! pip install seaborn
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import scipy as sc
from os import listdir
from os.path import isfile, join
import importlib
from IPython.display import clear_output
import numpy as np
import torch
import seaborn as sns
import matplotlib.pyplot as plt
import sys
sys.path.append('..')
from MRF.Training_parameters import *
from MRF.BaseModel import *
from MRF.Projection import *
from MRF.models import *
from MRF.Training_parameters import *
from MRF.Online import Network, Data_class, Performances
import pickle
clear_output()

# Visualize results

In [6]:
ALL = 'ALL'


def unique_sorted_values_plus_ALL(array):
    """Return the distinct values of ``array`` in ascending order, preceded by ``ALL``.

    ``array`` must provide a ``unique()`` method whose result has ``tolist()``
    (a pandas Series does, for instance).
    """
    distinct = sorted(array.unique().tolist())
    return [ALL, *distinct]
import ipywidgets as widgets
from IPython.display import display
# Names of the regular files stored in the folder of saved networks.
fnetworks = list(filter(
    lambda f: isfile(join('../save_networks_online', f)),
    listdir('../save_networks_online'),
))

# Checkpoint entries that are skipped when the settings of a network are printed.
exclude = (
    'num_files_validation', 'device', 'NN', 'urls', 'gradients', 'urls_file',
    'training_relative_errors', 'losses_validation', 'params_validation', 't', 'x',
    'validation_relative_errors', 'small_validation_relative_errors', 'dico_validation',
    'losses', 'losses_small_validation', 'losses_small_validation',
)

# Every checkpoint is loaded in turn; after the loop ``net`` holds the last one.
# NOTE(review): ``net`` stays undefined when the folder contains no file.
for i in range(len(fnetworks)):
    net = torch.load(
        join('../save_networks_online', fnetworks[i]),
        map_location='cpu',
    )

In [7]:
# Twelve-point grids: linearly spaced over [0.1, 0.5], then logarithmically
# spaced over [0.1, 6] and over [0.01, 3].
lsnllsm0s = np.linspace(start=0.1, stop=0.5, num=12)
lsnllst1 = np.logspace(start=np.log10(0.1), stop=np.log10(6), num=12)
lsnllst2 = np.logspace(start=np.log10(0.01), stop=np.log10(3), num=12)
# Flex container layout used for the rows of the dashboard.
item_layout = widgets.Layout(display='flex', justify_content='space-between')

# Sizing shared by the input widgets: width of the description label and
# overall width of the widget.
style = dict(description_width='250px')
layout = dict(width='500px')

def moving_average(a, n=3):
    """Return the means of ``a`` over a sliding window of ``n`` consecutive values.

    The result has ``len(a) - n + 1`` entries (none when ``n`` exceeds the
    number of values) and is always of float dtype.
    """
    running = np.cumsum(a, dtype=float)
    # From index n onwards, the sum over the last n values is the difference
    # between two running totals; earlier entries keep the plain running total.
    window_sums = running.copy()
    window_sums[n:] = running[n:] - running[:-n]
    return window_sums[n - 1:] / n

def print_inventory(dct):
    """Print a ``SETTINGS:`` header followed by one ``key (value)`` line per entry of ``dct``."""
    print("SETTINGS:")
    for key in dct:
        print(f"{key} ({dct[key]})")
        
def gram_schmidt(V):
    """Orthonormalize the columns of the 2-D array ``V`` with the Gram-Schmidt process.

    Column ``k`` of the result is column ``k`` of ``V`` minus its projections
    onto the previously computed columns, scaled to unit Euclidean norm.
    The result is a new float array with the same shape as ``V``.

    NOTE(review): there is no guard against a column that becomes zero
    (linearly dependent input); the normalization then divides by zero.
    """
    rows, cols = V.shape
    Q = np.zeros((rows, cols))

    first = V[:, 0]
    Q[:, 0] = first / np.linalg.norm(first)

    for k in range(1, cols):
        v_k = V[:, k]
        Q[:, k] = v_k
        # Each ``q`` is a view on one of the already orthonormalized columns.
        for q in Q[:, :k].T:
            Q[:, k] -= np.vdot(q, v_k) * q
        Q[:, k] /= np.linalg.norm(Q[:, k])
    return Q
    
def _use_seaborn_style(variant):
    """Activate Matplotlib's seaborn style sheet ``variant`` (e.g. 'white', 'whitegrid')."""
    try:
        plt.style.use('seaborn-' + variant)
    except OSError:
        # Matplotlib 3.6 renamed these style sheets to 'seaborn-v0_8-<variant>'.
        plt.style.use('seaborn-v0_8-' + variant)


def _plot_rows_in_two_columns(rows):
    """Plot each 1-D array of ``rows`` in its own panel of a two-column grid."""
    count = len(rows)
    # squeeze=False keeps ``axes`` two-dimensional even when the grid has a
    # single row, so the [row, column] indexing below always works.
    fig, axes = plt.subplots(nrows=(count + 1) // 2, ncols=2, figsize=(10, 35), squeeze=False)
    for idx, row in enumerate(rows):
        # Placement kept as it was: even indices go to the right-hand column,
        # odd indices to the left-hand one.
        panel = axes[idx // 2, (idx + 1) % 2]
        panel.plot(row)
        panel.set_xlabel('time (in s.)')
    plt.show()


def _plot_estimates(index, params, trueparams, crb_bars, MEANnet, STDnet, MEANnlls, STDnlls):
    """Draw, for fingerprint ``index``, one panel per parameter comparing truth, NN and NLLS.

    ``crb_bars[j]`` is the half-length of the error bar drawn around the true
    value of the j-th parameter.
    """
    bars = ('CRB', 'NN', 'NLLS')
    # One tick per estimator; the three error bars are drawn at x = 1, 2, 3.
    positions = list(range(1, len(bars) + 1))
    nbparams = len(params)
    plt.figure(figsize=(12, 3))
    _use_seaborn_style('white')
    for j in range(nbparams):
        plt.subplot(1, nbparams, j + 1)
        plt.ylabel(paramtoname[params[j]])
        plt.errorbar(1, trueparams[index, params[j]], crb_bars[j], fmt='o', color='black', ecolor='lightgray', elinewidth=3, capsize=0, label='CRB')
        plt.errorbar(2, MEANnet[index, j], STDnet[index, j], fmt='o', color='red', ecolor='lightgray', elinewidth=3, capsize=0, label='NN')
        plt.errorbar(3, MEANnlls[index, j], STDnlls[index, j], fmt='o', color='green', ecolor='lightgray', elinewidth=3, capsize=0, label='NLLS')
        plt.xticks(positions, bars)
        plt.tight_layout()
    plt.show()


def _plot_errors_along_epochs(net, kind, title_prefix, mvavg, indice):
    """Plot the smoothed training (and validation) ``kind`` errors, one panel per parameter.

    ``kind`` selects the checkpoint entries ``training_<kind>_errors`` and
    ``validation_<kind>_errors``; ``title_prefix`` is the first word of each title.
    """
    params = net['params']
    nb_params = len(params)
    plt.figure(figsize=(12, 5 * nb_params))
    _use_seaborn_style('whitegrid')
    for i in range(nb_params):
        training = np.array(net['training_' + kind + '_errors'])
        validation = np.array(net['validation_' + kind + '_errors']) if net['validation'] else None
        # With a single parameter the histories are indexed as 1-D sequences,
        # otherwise column i is selected.
        if nb_params == 1:
            train_errors = training[indice:]
            val_errors = None if validation is None else validation[indice:]
        else:
            train_errors = training[indice:, i]
            val_errors = None if validation is None else validation[indice:, i]
        train_errors = moving_average(train_errors, n=mvavg)
        plt.subplot(nb_params, 1, i + 1)
        first_epoch = mvavg + indice
        absc = list(range(first_epoch, first_epoch + train_errors.shape[0]))
        plt.plot(absc, train_errors, label='training')
        if val_errors is not None:
            plt.plot(absc, moving_average(val_errors, n=mvavg), label='validation')
        plt.legend()
        plt.title(title_prefix + ' error on ' + paramtoname[params[i]] + ' along epochs', size=20)
        plt.xlabel('Epochs')
    plt.show()


def common_filtering(net, name, type_plot, mvavg=1, indice=1, folder='Offline', num=0, freqcut=100):
    """Draw the visualization ``type_plot`` for a saved network inside the ``PLOT`` output widget.

    The settings file matching ``name`` is read from '../settings_files_online',
    the model named by ``net['name_model']`` is rebuilt and loaded with the
    weights in ``net['NN']``, and the requested plot or printout is produced.

    Parameters
    ----------
    net : mapping
        Checkpoint content (entries such as 'name_model', 'NN', 'params',
        'losses', ... are read depending on ``type_plot``).
    name : str
        File name of the checkpoint; 'network_' is stripped from it to build
        the name of the settings file.
    type_plot : str
        One of 'settings', 'first_layer',
        'Basis functions for the projection subspace', 'NN VS NLLS and CRB',
        'relative errors', 'absolute errors', 'error',
        'Singular values projection layer', 'samples', 'gradients wrt loss',
        'loss'. Any other value only clears the output.
    mvavg : int
        Window of the moving average ('relative errors', 'absolute errors', 'loss').
    indice : int
        First epoch that is plotted (same three plot types).
    folder : str
        Sub-folder of '../noise_files' handed to ``estimation.local_study``
        ('NN VS NLLS and CRB').
    num : int
        Index of the fingerprint to study ('NN VS NLLS and CRB'); when it is
        out of range, every fingerprint is plotted instead.
    freqcut : float
        Frequencies whose absolute value exceeds this threshold are zeroed
        ('Basis functions for the projection subspace').
    """
    model = importlib.import_module('MRF.models.'+net['name_model'])
    PLOT.clear_output()
    # NOTE(review): unpickling runs arbitrary code; only open settings files
    # that were produced by your own training runs.
    with open('../settings_files_online/settings_'+name.replace('network_','')+'.pkl', 'rb') as f:
        settings = pickle.load(f)
    training_parameters = Training_parameters(settings['batch_size'], 1, settings['nb_epochs'], settings['params'], settings['normalization'])
    projection = Projection(settings['start_by_projection'], settings['dimension_projection'], settings['initialization'], settings['normalization'])
    data_class = Data_class(training_parameters, settings['noise_type'], settings['noise_level'], settings['minPD'],
                            settings['maxPD'], settings['sampling'], settings['min_values'], settings['max_values'], settings['t2_wrt_t1'])
    validation_settings = {'validation': settings['validation'], 'validation_size': settings['validation_size']}
    netw = model.model(projection=projection, nb_params=len(settings['params']))
    netw.load_state_dict(net['NN'])
    netw.eval()
    estimation = Network(settings['model'], settings['loss'], training_parameters, settings['save_name'], data_class, settings['save_samples'], validation_settings, projection=projection)

    with PLOT:
        if type_plot == 'settings':
            print_inventory({k: net[k] for k in net.keys() if k not in exclude})

        elif type_plot == 'first_layer':
            # Imported here so that the save callback does not depend on the
            # ``scipy.io`` submodule having been loaded elsewhere.
            from scipy.io import savemat

            sns.set()
            layer = np.array(netw.fc1.weight.data)
            _plot_rows_in_two_columns(layer)

            button = widgets.Button(description="CLICK HERE to save those weights !", layout=widgets.Layout(width='50%', height='80px'))
            output = widgets.Output()
            display(button, output)

            def on_button_clicked(b):
                with output:
                    # savemat expects a dict mapping variable names to arrays.
                    savemat('weights_first_layer.mat', {'weights_first_layer': layer})
                    savemat('weights_first_layer_orthonormal.mat',
                            {'weights_first_layer_orthonormal': gram_schmidt(layer.T).T})
            button.on_click(on_button_clicked)

        elif type_plot == 'Basis functions for the projection subspace':
            from scipy.fftpack import fft, ifft, fftfreq

            layer = np.array(netw.fc1.weight.data)
            u, s, vt = np.linalg.svd(layer)
            # NOTE(review): with d=1 the frequencies are expressed per sample,
            # not in Hz; confirm the unit expected for ``freqcut``.
            W = fftfreq(vt.shape[1], d=1)
            rejected = np.abs(W) > freqcut
            filtered = []
            for i in range(u.shape[0]):
                spectrum = fft(vt[i, :])
                spectrum[rejected] = 0
                # Only the real part can be drawn; taking it explicitly avoids
                # the complex-to-real cast inside the plotting call.
                filtered.append(np.real(ifft(spectrum)))
            _plot_rows_in_two_columns(filtered)

        elif type_plot == 'NN VS NLLS and CRB':
            # ``join`` comes from the ``os.path`` import at the top of the
            # notebook; ``os`` itself is not imported in the visible cells.
            MEANnlls, STDnlls, MEANnet, STDnet, CRBs, trueparams = estimation.local_study(netw, join('../noise_files', folder))
            params = estimation.trparas.params
            nbparams = len(params)
            nb_fingerprints = CRBs.shape[0]

            # Valid indices are 0 .. nb_fingerprints - 1, as the message states.
            if num >= nb_fingerprints:
                print('No fingerprint exists for this number. Please choose a number between 0 and '+str(CRBs.shape[0]-1))
                for numero in range(nb_fingerprints):
                    print('True parameters', trueparams[numero,:])
                    _plot_estimates(numero, params, trueparams, [0.00001] * nbparams,
                                    MEANnet, STDnet, MEANnlls, STDnlls)
            else:
                print('True parameters', trueparams[num,:])
                _plot_estimates(num, params, trueparams, CRBs[num],
                                MEANnet, STDnet, MEANnlls, STDnlls)
                for j in range(nbparams):
                    print(paramtoname[params[j]]+' :',trueparams[num,params[j]])
                    print('                      sqrt{CRB} : ', round(CRBs[num,j],4) )
                    # ``np.float`` no longer exists in recent NumPy releases; the builtin is equivalent.
                    print('NEURAL NETWORK mean : ', round(MEANnet[num,j],4), '    std : ', round(STDnet[num,j],4), '    sqrt{CRB} / std   : ', round(CRBs[num,j]/float(STDnet[num,j]),4) )
                    print('NON LINEAR LS  mean : ', round(MEANnlls[num,j],4), '    std : ', round(STDnlls[num,j],4), '    sqrt{CRB} / std   : ',round( CRBs[num,j]/float(STDnlls[num,j]),4) )
                    print('\n')
                    print('\n')

        elif type_plot == 'relative errors':
            _plot_errors_along_epochs(net, 'relative', 'Relative', mvavg, indice)

        elif type_plot == 'absolute errors':
            _plot_errors_along_epochs(net, 'absolute', 'absolute', mvavg, indice)

        elif type_plot == 'error':
            VP = net['params_validation']
            with torch.no_grad():
                outputs = netw(net['dico_validation'])
                VPerrors = np.abs(estimation.transform(outputs)-VP[:,net['params']])/VP[:,net['params']]
            VP, VPerrors = np.array(VP), np.array(VPerrors)
            sns.set_style("ticks")
            fig, ax = plt.subplots()
            # Column of the error array used to colour the points.
            ind = 0 if len(net['params']) == 1 else 1
            # Not named ``sc``: that would shadow the module-level scipy alias
            # for the whole function, including the save callback above.
            points = ax.scatter(np.log10(VP[:,2]), np.log10(VP[:,1]), c=VPerrors[:,ind], s=100, edgecolor='none')
            plt.colorbar(points)
            ax.set_xlabel('log T2')
            ax.set_ylabel('log T1')
            plt.title('Validation Error')
            plt.show()

        elif type_plot == 'Singular values projection layer':
            sns.set()
            layer = np.array(netw.fc1.weight.data)
            _, s, _ = np.linalg.svd(layer)
            plt.plot(s)
            plt.xlabel('Number of the singular values')
            plt.ylabel('Singular values of the first linear layer')
            plt.title('Singular values of the projection layer', size=18)
            plt.show()

        elif type_plot == 'samples':
            sampled_points = net['samples']
            plt.scatter(np.log10(sampled_points[:,2]),np.log10(sampled_points[:,1]))
            plt.xlabel('log(T2)')
            plt.ylabel('log(T1)')
            plt.title('Points sampled during training')
            plt.show()

        elif type_plot == 'gradients wrt loss':
            sns.set_style("ticks")
            fig, ax = plt.subplots()
            plt.plot(net['losses'],net['gradients'])
            ax.set_xlabel('Training loss per epoch')
            ax.set_ylabel('Average of the gradient norm per epoch')
            plt.title('Norm of the gradients of the parameters wrt the training loss',size=15)
            plt.show()

        elif type_plot == 'loss':
            plt.figure(figsize=(12,5))
            _use_seaborn_style('whitegrid')
            train_loss = moving_average(net['losses'][indice:], n=mvavg)
            val_loss = moving_average(net['losses_validation'][indice:], n=mvavg)
            plt.subplot()
            first_epoch = mvavg + indice
            absc = list(range(first_epoch, first_epoch + train_loss.shape[0]))
            plt.plot(absc, train_loss, label='training')
            plt.plot(absc, val_loss, label='validation')
            plt.legend()
            plt.xlabel('Epochs')
            plt.title('Loss along epochs',size=20)
            plt.show()

# ``os`` is not imported explicitly in the visible cells of this notebook.
import os

# Output area in which every visualization is rendered.
PLOT = widgets.Output()
PLOT.clear_output()
options = {'settings', 'error', 'loss', 'samples', 'absolute errors', 'relative errors', 'NN VS NLLS and CRB', 'gradients wrt loss'}


# The projection-layer plots are only offered when the network that was loaded
# last starts with a projection layer.
if net['start_by_projection']:
    options.update({'first_layer', 'Singular values projection layer', 'Basis functions for the projection subspace'})

# A set has no stable iteration order; sorting makes the menu identical from one run to the next.
dropdown_type = widgets.Dropdown(options = sorted(options), value='settings', description='Visualization : ', style=style, layout=layout)

dropdown_network = widgets.Dropdown(options = fnetworks  , description='Chosen network: ', style=style, layout=layout)

dropdown_indice = widgets.BoundedIntText(min=0, max=100000, value=1, step=1, description='Starting epoch: ', style=style, layout=layout)
mv_avg = widgets.BoundedIntText(min=1, max=100000, value=1, step=1, description='Number of epochs for moving average: ', style=style, layout=layout)

# Sub-folders of '../noise_files'. The default keeps the dashboard usable for
# the other plots when that folder does not exist (the menu is then empty).
noise_folders = sorted(next(os.walk('../noise_files'), (None, [], None))[1])
dropdown_folder = widgets.Dropdown(
                    options=noise_folders,
                    description='folder nlls:', style=style, layout=layout)

dropdown_num = widgets.BoundedIntText(
                    value=0,
                    description='Number of the fingerprint to study:',
                    disabled=False,
                    min=0,
                    style=style, layout=layout)
dropdown_freqcut = widgets.FloatText(
                    value=10000,
                    description='Cut frequency:',
                    disabled=False,
                    min=0,
                    style=style, layout=layout)

# Plot types that use the moving-average window and the starting epoch.
# 'absolute errors' is included because it is smoothed exactly like
# 'relative errors'; previously its two inputs were never displayed.
_SMOOTHED_PLOTS = ('relative errors', 'absolute errors', 'loss')


def print_form(**func_kwargs):
    """Return the values currently shown in the form, keyed by a short label.

    The values are read from the children of the global interactive widget
    ``i``; which children exist depends on the selected plot type (see
    ``for_mv_avg``). ``func_kwargs`` is accepted because ``interactive``
    passes the widget values as keyword arguments, but it is not used.
    """
    children = i.children
    plottype = children[1].value
    form = {'Network': children[0].value, 'Type': plottype}
    if plottype in _SMOOTHED_PLOTS:
        form['Mv Avg'] = children[2].value
        form['Indice'] = children[3].value
    elif plottype == 'NN VS NLLS and CRB':
        form['Folder'] = children[2].value
        form['Num'] = children[3].value
    elif plottype == 'Basis functions for the projection subspace':
        form['Freqcut'] = children[2].value
    return form


def for_mv_avg(plottype):
    """Show the inputs needed by the newly selected plot type.

    ``plottype`` is the change notification of ``dropdown_type``; its ``new``
    attribute is the selected value. A fresh interactive widget is built with
    the matching inputs and its children replace those of the global ``i``.
    """
    selected = plottype.new
    if selected in _SMOOTHED_PLOTS:
        extra = dict(mvavg=mv_avg, indice=dropdown_indice)
    elif selected == 'NN VS NLLS and CRB':
        extra = dict(folder=dropdown_folder, num=dropdown_num)
    elif selected == 'Basis functions for the projection subspace':
        extra = dict(freqcut=dropdown_freqcut)
    else:
        extra = {}
    # Keyword order fixes the order of the children that print_form indexes.
    new_i = widgets.interactive(print_form, network=dropdown_network, plottype=dropdown_type, **extra)
    i.children = new_i.children
dropdown_type.observe(for_mv_avg, 'value')

def on_button_clicked(b):
    """Load the selected checkpoint on the CPU and draw the selected visualization."""
    net = torch.load(
        join('../save_networks_online', dropdown_network.value),
        map_location='cpu',
    )
    common_filtering(
        net,
        dropdown_network.value,
        dropdown_type.value,
        mvavg=mv_avg.value,
        indice=dropdown_indice.value,
        folder=dropdown_folder.value,
        num=dropdown_num.value,
        freqcut=dropdown_freqcut.value,
    )


# Form with the network and plot-type menus; its children are swapped by the
# observer of ``dropdown_type`` when another plot type is selected.
i = widgets.interactive(print_form, network=dropdown_network, plottype=dropdown_type)

button_layout = widgets.Layout(width='30%', height='80px')
button = widgets.Button(description="CLICK HERE to change the network or plot !", layout=button_layout)
button.on_click(on_button_clicked)

# Form and button side by side, with the plot area underneath.
input_widgets = widgets.HBox([i, button], layout=item_layout)
tab = widgets.Tab([PLOT], layout=item_layout)
dashboard = widgets.VBox([input_widgets, PLOT])
display(dashboard)

VBox(children=(HBox(children=(interactive(children=(Dropdown(description='Chosen network: ', layout=Layout(wid…