# Interactive tool for estimation of parameters in quantitative MRI

# Setting parameters for a new training: $\color{red}{\text{ONLINE FRAMEWORK}}$ 

## Visiting Student: Quentin Duchemin
## Professors: Carlos Fernandez Granda & Jakob Assländer


### 1) Presentation of the project

We estimate the biological parameters in quantitative MRI using neural networks. This notebook is an interactive tool for defining the settings of your next training run. Use it only if you want the fingerprints to be computed $\color{red}{\text{ONLINE}}$.

### 2) How to use this interface?

This tool lets you save a file containing the settings for a new training run.

### 3) How to launch a training run once the settings file is saved?

To launch a training run, follow the instructions below:

- Select the desired parameters and click the button to save the file.
- From the folder **online_built_fingerprints**, run the command: **python main.py --save_name 'name_chosen'**. Make sure to replace **'name_chosen'** with the name you chose when saving the settings file.

In [2]:
import sys
sys.path.append('..')
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
from os import listdir
from os.path import isfile, join
import importlib
from IPython.display import clear_output
import numpy as np
from MRF.Online import Network
from MRF.Training_parameters import *
import MRF
import warnings
import os
# Silence library warnings so they do not clutter the widget output.
warnings.filterwarnings("ignore")


# Shared look-and-feel for the form: flex rows with items pushed apart, a
# wide description column and a fixed overall widget width.
item_layout = widgets.Layout(display='flex', justify_content='space-between')
style = {'description_width': '250px'}
layout = {'width': '700px'}

# --- Architecture / projection section -------------------------------------

# Network architecture, one of the module names in MRF.models.__all__.
dropdownmodel = widgets.Select(
    options=MRF.models.__all__,
    value='joint_deep3',
    description='Select one architecture.',
    style=style,
    layout=layout,
)

# Initialization scheme for the first linear layer.
initialization = widgets.Select(
    options=Initialization.list(),
    value='Random',
    description='Initialization for the first linear layer:',
    disabled=False,
    style=style,
    layout=layout,
)

# Dimension of the projection subspace.
# NOTE(review): value=None presumably leaves the widget at its default
# integer — confirm this is the intended initial value.
dimension_projection = widgets.IntText(
    value=None,
    description='Dimension projection subspace:',
    disabled=False,
    style=style,
    layout=layout,
)

# Disabled so the user cannot toggle it directly; its value appears to be
# driven programmatically by the projection-form callback.
start_by_projection = widgets.Checkbox(
    value=True,
    description='Start by a projection',
    disabled=True,
    layout=layout,
)

normalization = widgets.Select(
    options=Normalization.list(),
    value='Without',
    description='Normalization:',
    style=style,
    layout=layout,
)

# File name of the PCA basis used for the projection.
namepca = widgets.Text(
    value='basis_gaussian.mat',
    description='Name file PCA:',
    disabled=False,
    style=style,
    layout=layout,
)

def print_form_proj(**func_kwargs) -> dict:
    """Callback of the ``iproj`` interactive form (architecture / projection).

    Imports ``MRF.models.<selected architecture>``, builds a model instance
    and uses its ``start_by_projection`` attribute to drive the form:

    * if the model starts with a projection, the widget at
      ``iproj.children[1]`` (presumably the start_by_projection checkbox) is
      set to True, and the selector at ``iproj.children[2]`` (presumably the
      normalization selector) is offered every ``Normalization.list()``
      option;
    * otherwise that checkbox is set to False and only the first two
      normalization options are offered.

    The keyword arguments passed in by ``widgets.interactive`` are ignored.
    Values are read back from ``iproj.children`` by position. This presumably
    relies on the order in which the widgets were passed when ``iproj`` was
    built (mod, sbp, norm[, ini, dim, name]) — confirm.

    Returns:
        dict summarising the current form values. Keys are 'mod', 'Sbp' and
        'Norm', plus 'Ini', 'Dim' and 'name' when the model starts with a
        projection.
    """
    MOD = importlib.import_module('MRF.models.'+dropdownmodel.value)
    # NOTE(review): the meaning of the constructor arguments (40, False, 3) is
    # not visible here — confirm against MRF.models.<name>.model. Building a
    # model on every form update may also be slow.
    net = MOD.model(40, False, 3)
    if net.start_by_projection:
        # Writing this value may fire the `for_proj` observer registered on
        # start_by_projection, which replaces iproj.children while this
        # callback is still running — TODO confirm this re-entrancy is safe.
        iproj.children[1].value = True
        # NOTE(review): reassigning options may reset the current selection
        # depending on the ipywidgets version — verify.
        iproj.children[2].options = Normalization.list()
        return {'mod': iproj.children[0].value,
                'Sbp': iproj.children[1].value,
                'Norm': iproj.children[2].value,
                'Ini': iproj.children[3].value,
                'Dim': iproj.children[4].value,
                'name': iproj.children[5].value}
    else:
        iproj.children[1].value = False
        # Without a projection, only the first two normalization modes are offered.
        iproj.children[2].options = Normalization.list()[:2] 
        return {'mod': iproj.children[0].value,
                'Sbp': iproj.children[1].value,
                'Norm': iproj.children[2].value}

def for_proj(valid):
    """Rebuild the projection form when "Start by a projection" changes.

    ``valid`` is the change notification delivered by ``observe``, and
    ``valid.new`` is the checkbox's new value. When it is truthy, the
    initialization, subspace-dimension and PCA-file widgets are added to the
    form. Otherwise only the model, checkbox and normalization widgets are
    kept. The children of the freshly built interactive replace those of the
    displayed ``iproj`` widget.

    NOTE(review): the previous interactive is not torn down, so its
    callbacks on these widgets may remain registered — confirm.
    """
    form_widgets = dict(mod=dropdownmodel, sbp=start_by_projection, norm=normalization)
    if valid.new:
        form_widgets.update(ini=initialization, dim=dimension_projection, name=namepca)
    rebuilt = widgets.interactive(print_form_proj, **form_widgets)
    iproj.children = rebuilt.children
# Re-layout the projection form whenever the "Start by a projection" box flips.
start_by_projection.observe(for_proj, names='value')

# Build the projection form with all six inputs and show it in a row above a
# (cleared) output area.
PROJ = widgets.Output()
PROJ.clear_output()
iproj = widgets.interactive(
    print_form_proj,
    mod=dropdownmodel,
    sbp=start_by_projection,
    norm=normalization,
    ini=initialization,
    dim=dimension_projection,
    name=namepca,
)
input_widgetsproj = widgets.HBox([iproj], layout=item_layout)
dashboardproj = widgets.VBox([input_widgetsproj, PROJ])
display(dashboardproj)

# Optimizer used for training; its value is saved as dic['optimizer'].
# The previous version copy-pasted the initialization widget. It was labelled
# "Initialization for the first linear layer:" and defaulted to 'Random',
# which is not an optimizer name and raises if absent from Optimizer.list().
# With no explicit value, the Select starts on the first listed optimizer.
# NOTE(review): set value= explicitly if a specific default optimizer is wanted.
optimizer = widgets.Select(
    options=Optimizer.list(),
    description='Optimizer:',
    disabled=False,
    style=style, layout=layout)
display(optimizer)

# --- Optimisation hyper-parameters -----------------------------------------

# Learning rate, bounded to [1e-7, 100].
lr = widgets.BoundedFloatText(value=0.001, min=0.0000001, max=100., step=0.00001,
                              description='Learning Rate:', disabled=False,
                              style=style, layout=layout)
display(lr)

# --- Noise model and parameter sampling ------------------------------------

noise_type = widgets.Select(options=NoiseType.list(), value='Standard',
                            description='Noise type:',
                            style=style, layout=layout)
display(noise_type)

noise_level = widgets.BoundedFloatText(value=1. / 30., min=0.000000000001, max=200.,
                                       description='Noise Level:',
                                       style=style, layout=layout)
display(noise_level)

sampling = widgets.Select(options=Sampling.list(), value='Log',
                          description='Sampling:',
                          style=style, layout=layout)
display(sampling)

# --- Loop sizes ---------------------------------------------------------------

batch_size = widgets.IntText(value=64, description='Batch Size:', disabled=False,
                             style=style, layout=layout)
display(batch_size)

nb_iterations = widgets.IntText(value=20, description='Nb Iterations:', disabled=False,
                                style=style, layout=layout)
display(nb_iterations)

nb_epochs = widgets.IntText(value=10000, description='Nb Epochs:', disabled=False,
                            style=style, layout=layout)
display(nb_epochs)

# --- Estimated parameters and their losses ---------------------------------

# The selected key is mapped through `nametoparam` when the settings are saved.
params_name = widgets.Select(
    options=nametoparam.keys(), value='The three parameters',
    description='Parameters to estimate:', disabled=False,
    style=style, layout=layout)
display(params_name)


def _loss_select(description):
    """Create, display and return a loss selector defaulting to 'MSE-Log'."""
    selector = widgets.Select(
        options=Loss.list(), value='MSE-Log', description=description,
        disabled=False, style=style, layout=layout)
    display(selector)
    return selector


# One loss per quantity, in the order they are stored in dic['loss'].
lossm0s = _loss_select('Loss m0s:')
losst1f = _loss_select('Loss t1f:')
losst2f = _loss_select('Loss t2f:')
lossr = _loss_select('Loss r:')
losst1s = _loss_select('Loss t1s:')
losst2s = _loss_select('Loss t2s:')

def _range_slider(description, value, lo, hi, step, **fmt):
    """Create, display and return a non-continuous horizontal FloatRangeSlider.

    ``lo``/``hi`` are the slider bounds and ``value`` the initial [low, high]
    selection. Extra keyword arguments (e.g. ``readout_format``) are forwarded
    to the widget unchanged.
    """
    slider = widgets.FloatRangeSlider(
        value=value,
        min=lo,
        max=hi,
        step=step,
        description=description,
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        style=style, layout=layout,
        **fmt)
    display(slider)
    return slider


# Sampling ranges for each parameter. When the settings are saved, the lower
# and upper ends are collected into dic['min_values'] / dic['max_values'].
ranget1 = _range_slider('Range for T1:', [0.1, 6], 0, 7, 0.001, readout_format='.3f')
ranget2 = _range_slider('Range for T2:', [0.01, 3], 0, 4, 0.001, readout_format='.3f')
rangem0s = _range_slider('Range for m0s:', [0.1, 0.5], 0, 0.7, 0.001, readout_format='.3f')
# Integer-stepped slider; keeps the widget's default readout format.
ranger = _range_slider('Range for R:', [10, 100], 0, 110, 1)
ranget2s = _range_slider('Range for T2s:', [0.001, 0.1], 0.001, 0.1, 0.001, readout_format='0.3f')

# Bounds on the proton density, saved as dic['minPD'] / dic['maxPD'].
# NOTE(review): nothing here enforces minPD <= maxPD — confirm downstream checks.
minPD = widgets.FloatText(
    value=0.1,
    description='minimum proton density:',  # fixed typo ("minimun")
    disabled=False,
    style=style, layout=layout)
display(minPD)

maxPD = widgets.FloatText(
    value=1.,
    description='maximum proton density:',
    disabled=False,
    style=style, layout=layout)
display(maxPD)

# --- Miscellaneous training options -----------------------------------------

# Constraint option relating T2 to T1.
t2_wrt_t1 = widgets.Select(options=T2wrtT1.list(), value='below_percent',
                           description='T2 wrt T1:', disabled=False,
                           style=style, layout=layout)
display(t2_wrt_t1)

# Name used in the settings file name (settings_<name>.pkl) and passed to
# main.py via --save_name.
save_name = widgets.Text(value='1', placeholder='Enter the name desired',
                         description='Saving name:', disabled=False, layout=layout)
display(save_name)

save_samples = widgets.Checkbox(value=True, description='Save samples during training',
                                disabled=False, layout=layout)
display(save_samples)

# Validation widgets: not displayed here directly; they are shown through the
# `i` interactive form built further below.
validation = widgets.Checkbox(value=True, description='Keep validation error',
                              disabled=False, layout=layout)

validation_size = widgets.IntText(value=1000, description='Validation size:',
                                  disabled=False, style=style, layout=layout)

def print_form_valid(**func_kwargs):
    """Callback of the ``i`` interactive form (validation settings).

    The keyword arguments from ``widgets.interactive`` are ignored. Values are
    read from ``i.children`` by position instead.

    Returns:
        ``{'Valid': <checkbox value>}``, plus ``'Svs'`` (the validation size)
        when validation is enabled.
    """
    summary = {'Valid': i.children[0].value}
    if summary['Valid']:
        summary['Svs'] = i.children[1].value
    return summary
def for_valid(valid):
    """Rebuild the validation form when "Keep validation error" changes.

    ``valid.new`` is the checkbox's new value. When it is truthy, the
    validation-size field is included in the rebuilt form; otherwise only the
    checkbox is kept. The children of the freshly built interactive replace
    those of the displayed ``i`` widget.
    """
    form_widgets = {'valid': validation}
    if valid.new:
        form_widgets['svs'] = validation_size
    rebuilt = widgets.interactive(print_form_valid, **form_widgets)
    i.children = rebuilt.children
# Rebuild the validation form whenever its checkbox flips.
validation.observe(for_valid, names='value')

# Validation form (checkbox + size) shown in a row above a cleared output area.
VALID = widgets.Output()
VALID.clear_output()
i = widgets.interactive(print_form_valid, valid=validation, svs=validation_size)
input_widgets = widgets.HBox([i], layout=item_layout)
dashboard = widgets.VBox([input_widgets, VALID])
display(dashboard)


# Save button, plus the output area its click handler prints into.
button = widgets.Button(
    description="CLICK HERE to save the settings previously defined for the training !",
    button_style='success',
    layout=widgets.Layout(width='50%', height='80px'),
)
output_params = widgets.Output()

def on_button_clicked(b):
    """Collect every widget value and pickle it as a training-settings file.

    The file is written to ``../settings_files_online/settings_<save_name>.pkl``,
    creating the directory if needed. A confirmation (or an error message for
    an empty saving name) is printed into ``output_params``.

    Parameters
    ----------
    b : ipywidgets.Button
        The clicked button (unused; required by ``Button.on_click``).
    """
    import pickle

    with output_params:
        name = save_name.value
        # An empty name would silently produce "settings_.pkl", which cannot
        # be selected sensibly with `main.py --save_name`.
        if not name.strip():
            print('Please enter a saving name before saving the settings.')
            return

        dic = {
            'optimizer': optimizer.value,
            'lr': lr.value,
            'model': dropdownmodel.value,
            'noise_type': noise_type.value,
            'noise_level': noise_level.value,
            'sampling': sampling.value,
            'loss': [lossm0s.value, losst1f.value, losst2f.value,
                     lossr.value, losst1s.value, losst2s.value],
            'batch_size': batch_size.value,
            'nb_iterations': nb_iterations.value,
            'nb_epochs': nb_epochs.value,
            'params': nametoparam[params_name.value],
            'initialization': initialization.value,
            'save_samples': save_samples.value,
            'validation': validation.value,
            'validation_size': validation_size.value,
            'start_by_projection': start_by_projection.value,
            'minPD': minPD.value,
            'maxPD': maxPD.value,
            'normalization': normalization.value,
            't2_wrt_t1': t2_wrt_t1.value,
            # Order: m0s, T1, T2, R, T2s.
            'min_values': [rangem0s.value[0], ranget1.value[0], ranget2.value[0],
                           ranger.value[0], ranget2s.value[0]],
            'max_values': [rangem0s.value[1], ranget1.value[1], ranget2.value[1],
                           ranger.value[1], ranget2s.value[1]],
            'save_name': name,
            'dimension_projection': dimension_projection.value,
        }

        settings_dir = os.path.join('..', 'settings_files_online')
        # exist_ok avoids the exists-then-mkdir race and also creates any
        # missing parent directories.
        os.makedirs(settings_dir, exist_ok=True)
        path = os.path.join(settings_dir, 'settings_' + name + '.pkl')
        # The context manager closes the file even if pickling fails.
        with open(path, 'wb') as f:
            pickle.dump(dic, f)
        print('Settings have been saved.')


button.on_click(on_button_clicked)

display(button, output_params)

VBox(children=(HBox(children=(interactive(children=(Select(description='Select one architecture.', index=13, l…

BoundedFloatText(value=0.001, description='Learning Rate:', layout=Layout(width='700px'), min=1e-07, step=1e-0…

Select(description='Noise type:', index=1, layout=Layout(width='700px'), options=('SNR', 'Standard'), style=De…

BoundedFloatText(value=0.03333333333333333, description='Noise Level:', layout=Layout(width='700px'), max=200.…

Select(description='Sampling:', layout=Layout(width='700px'), options=('Log', 'Uniform'), style=DescriptionSty…

IntText(value=64, description='Batch Size:', layout=Layout(width='700px'), style=DescriptionStyle(description_…

IntText(value=20, description='Nb Iterations:', layout=Layout(width='700px'), style=DescriptionStyle(descripti…

IntText(value=10000, description='Nb Epochs:', layout=Layout(width='700px'), style=DescriptionStyle(descriptio…

Select(description='Parameters to estimate:', index=3, layout=Layout(width='700px'), options=('m0s', 'T1', 'T2…

Select(description='Loss m0s:', layout=Layout(width='700px'), options=('MSE-Log', 'MSE', 'MSE-Inverse', 'MSE-S…

Select(description='Loss t1f:', layout=Layout(width='700px'), options=('MSE-Log', 'MSE', 'MSE-Inverse', 'MSE-S…

Select(description='Loss t2f:', layout=Layout(width='700px'), options=('MSE-Log', 'MSE', 'MSE-Inverse', 'MSE-S…

Select(description='Loss r:', layout=Layout(width='700px'), options=('MSE-Log', 'MSE', 'MSE-Inverse', 'MSE-Sca…

Select(description='Loss t1s:', layout=Layout(width='700px'), options=('MSE-Log', 'MSE', 'MSE-Inverse', 'MSE-S…

Select(description='Loss t2s:', layout=Layout(width='700px'), options=('MSE-Log', 'MSE', 'MSE-Inverse', 'MSE-S…

FloatRangeSlider(value=(0.1, 6.0), continuous_update=False, description='Range for T1:', layout=Layout(width='…

FloatRangeSlider(value=(0.01, 3.0), continuous_update=False, description='Range for T2:', layout=Layout(width=…

FloatRangeSlider(value=(0.1, 0.5), continuous_update=False, description='Range for m0s:', layout=Layout(width=…

FloatRangeSlider(value=(10.0, 100.0), continuous_update=False, description='Range for R:', layout=Layout(width…

FloatRangeSlider(value=(0.001, 0.1), continuous_update=False, description='Range for T2s:', layout=Layout(widt…

FloatText(value=0.1, description='minimun proton density:', layout=Layout(width='700px'), style=DescriptionSty…

FloatText(value=1.0, description='maximum proton density:', layout=Layout(width='700px'), style=DescriptionSty…

Select(description='T2 wrt T1:', index=2, layout=Layout(width='700px'), options=('no_constraint', 'below', 'be…

Text(value='1', description='Saving name:', layout=Layout(width='700px'), placeholder='Enter the name desired'…

Checkbox(value=True, description='Save samples during training', layout=Layout(width='700px'))

VBox(children=(HBox(children=(interactive(children=(Checkbox(value=True, description='Keep validation error', …

Button(button_style='success', description='CLICK HERE to save the settings previously defined for the trainin…

Output()

In [None]:
import pickle

# Load and print an existing settings file for inspection.
# NOTE(review): this reads from settings_files/, whereas the save button above
# writes to ../settings_files_online/ — confirm which file is intended.
settings_path = 'settings_files/settings_21.pkl'
with open(settings_path, 'rb') as f:
    default_settings = pickle.load(f)
print(default_settings)

In [None]:
# Scratch cell: draw one random parameter set.
# Fixes: `random` was never imported, the body carried a stray tab indent
# (IndentationError), and `random.random(0, 1)` is a TypeError because
# random.random() takes no arguments.
import random

import numpy as np

# Reseed both generators from the system entropy source so each run differs.
random.seed()
np.random.seed()

m0s = random.uniform(0, 0.7)                          # in [0, 0.7]
t1 = 2.8 * random.uniform(0, 1) + 0.2                 # in [0.2, 3.0]
t2f = t1 * (random.uniform(0, 1) * 0.5 + 0.005)       # between 0.5% and 50.5% of t1
r = 490 * random.uniform(0, 1) + 10                   # in [10, 500]
t2s = 0.2 * 10**(-3) + random.uniform(0, 1) * 150 * 10**(-3)  # in [0.2e-3, 150.2e-3]