# Interactive tool for estimating parameters in quantitative MRI

# Setting parameters for a new training run


### 1) Presentation of the project

We aim to estimate biological parameters in quantitative MRI using neural networks. This notebook is an interactive tool for defining the settings of your next training run.

### 2) How to use this interface?

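This tool lets you save a file containing the settings for a new training run. Fill in the widgets below, then click the save button at the bottom: the settings are written to `../settings_files_offline/settings_<saving name>.pkl`.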

### 3) How to launch a training run once the settings file is saved?

To launch a training run, follow the instructions below:

- Select the desired parameters and click the button to save the file.
- From the folder **offline_built_fingerprints**, run the command **python main.py --save_name 'name_chosen'**, replacing **'name_chosen'** with the saving name you chose for the settings file.

In [6]:
import sys
sys.path.append('..')
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import os
from os import listdir
from os.path import isfile, join
import importlib
from IPython.display import clear_output
import numpy as np
from MRF.Offline import Network
from MRF.Training_parameters import *
import MRF
import warnings

# Suppress every warning, from any library, for the rest of the session.
warnings.filterwarnings("ignore")

# Presentation settings shared by the form widgets below.
item_layout = widgets.Layout(display='flex', justify_content='space-between')
style = {'description_width': '250px'}
layout = {'width': '700px'}
_wide = dict(style=style, layout=layout)  # wide label + wide widget, reused below

# Network architecture: one entry per name listed in MRF.models.__all__.
dropdownmodel = widgets.Select(options=MRF.models.__all__,
                               value='CRB-paper',
                               description='Select one architecture.',
                               **_wide)

# Initialization scheme of the first linear layer.
initialization = widgets.Select(options=Initialization.list(),
                                value='Random',
                                description='Initialization for the first linear layer:',
                                disabled=False,
                                **_wide)

# Dimension of the projection subspace (no explicit initial value is given).
dimension_projection = widgets.IntText(value=None,
                                       description='Dimension projection subspace:',
                                       disabled=False,
                                       **_wide)

# Read-only checkbox: print_form_proj sets it from the selected model.
start_by_projection = widgets.Checkbox(value=True,
                                       description='Start by a projection',
                                       disabled=True,
                                       layout=layout)

# Normalization scheme. print_form_proj narrows the choices to the first two
# entries when the selected model does not start by a projection.
normalization = widgets.Select(options=Normalization.list(),
                               value='Without',
                               description='Normalization:',
                               **_wide)

# Name of the PCA basis file.
namepca = widgets.Text(value='basis_gaussian.mat',
                       description='Name file PCA:',
                       disabled=False,
                       **_wide)

def print_form_proj(**func_kwargs):
    """Callback of the projection ``interactive`` form (global ``iproj``).

    Imports ``MRF.models.<selected architecture>`` and builds its ``model``
    with ``ghost=True`` to read the model's ``start_by_projection`` flag.
    NOTE(review): ``ghost=True`` presumably yields a lightweight instance
    used only to read attributes; confirm.
    The form's checkbox and normalization choices are then synced with that
    flag, and a dict summarizing the current form values is returned.

    ``iproj.children`` follow the keyword order of the ``widgets.interactive``
    calls: model, start_by_projection, normalization, and, only when the form
    includes them, initialization, dimension_projection and namepca.

    Args:
        **func_kwargs: widget values supplied by ``interactive``. They are
            unused; the values are read from ``iproj.children`` instead.

    Returns:
        dict: keys 'mod', 'Sbp', 'Norm', plus 'Ini', 'Dim', 'name' when the
        selected model starts by a projection.
    """
    MOD = importlib.import_module('MRF.models.'+dropdownmodel.value)
    net = MOD.model(ghost=True)
    if net.start_by_projection:
        # If this changes the checkbox, the observer registered with
        # start_by_projection.observe(for_proj, 'value') swaps iproj.children
        # to the six-widget form before indices 3-5 are read below.
        iproj.children[1].value = True
        iproj.children[2].options = Normalization.list()
        return {'mod': iproj.children[0].value,
                'Sbp': iproj.children[1].value,
                'Norm': iproj.children[2].value,
                'Ini': iproj.children[3].value,
                'Dim': iproj.children[4].value,
                'name': iproj.children[5].value}
    else:
        # Same observer: switches iproj.children to the three-widget form.
        iproj.children[1].value = False
        # Without a projection, only the first two normalization options are offered.
        iproj.children[2].options = Normalization.list()[:2] 
        return {'mod': iproj.children[0].value,
                'Sbp': iproj.children[1].value,
                'Norm': iproj.children[2].value}

def for_proj(valid):
    """Observer on ``start_by_projection``: rebuild the projection form.

    Args:
        valid: the change notification; ``valid.new`` is the new checkbox value.

    When the checkbox becomes True, the form shows the model, checkbox,
    normalization, initialization, projection-dimension and PCA-file
    widgets. When it becomes False, only the first three are shown. The new
    ``interactive``'s children are swapped into the existing global ``iproj``,
    so the already-displayed dashboard updates in place.

    NOTE(review): each call builds a fresh ``widgets.interactive`` on the same
    shared widgets, and nothing here detaches the previous ones. The number of
    ``print_form_proj`` invocations per widget change may therefore grow with
    every toggle. Confirm whether this matters.
    """
    if valid.new:
        new_iproj = widgets.interactive(print_form_proj, mod=dropdownmodel, sbp=start_by_projection, norm=normalization, ini=initialization, dim=dimension_projection, name=namepca)
        iproj.children = new_iproj.children
    else:
        new_iproj = widgets.interactive(print_form_proj, mod=dropdownmodel, sbp=start_by_projection, norm=normalization)
        iproj.children = new_iproj.children
# Rebuild the form whenever the checkbox value changes (print_form_proj sets
# it from the selected model). Registered before iproj exists; for_proj only
# looks iproj up when it runs.
start_by_projection.observe(for_proj, 'value')
# Output area shown below the form; nothing visible here writes to it.
PROJ = widgets.Output()
PROJ.clear_output()
# Initial form with all six widgets (start_by_projection starts as True).
iproj = widgets.interactive(print_form_proj, mod=dropdownmodel, sbp=start_by_projection, norm=normalization, ini=initialization, dim=dimension_projection, name = namepca)
input_widgetsproj = widgets.HBox([iproj], layout=item_layout)
dashboardproj = widgets.VBox([input_widgetsproj, PROJ])
display(dashboardproj)



# Optimizer used for training (choices come from Optimizer.list()).
# Its value is saved under the 'optimizer' key of the settings file.
optimizer = widgets.Select(
    options=Optimizer.list(),
    value='SGD',
    description='Optimizer:',
    disabled=False,
    style=style, layout=layout)
display(optimizer)

# Learning rate, bounded to [0, 100].
lr = widgets.BoundedFloatText(
    value=0.001,
    min=0.0,
    max=100.0,
    description='Learning Rate:',
    disabled=False,
    style=style,
    layout=layout,
)
display(lr)

# Kind of noise to use (choices come from NoiseType.list()).
noise_type = widgets.Select(
    options=NoiseType.list(),
    value='Standard',
    description='Noise type:',
    style=style,
    layout=layout,
)
display(noise_type)

# Noise level, bounded to [1e-12, 200]; starts at 1/30.
noise_level = widgets.BoundedFloatText(
    value=1.0 / 30.0,
    min=1e-12,
    max=200.0,
    description='Noise Level:',
    style=style,
    layout=layout,
)
display(noise_level)

# Mini-batch size.
batch_size = widgets.IntText(
    value=64,
    description='Batch Size:',
    disabled=False,
    style=style,
    layout=layout,
)
display(batch_size)

# Number of training epochs.
nb_epochs = widgets.IntText(
    value=10000,
    description='Nb Epochs:',
    disabled=False,
    style=style,
    layout=layout,
)
display(nb_epochs)

# Which parameters to estimate. The selected label is translated through
# nametoparam when the settings are saved.
params_name = widgets.Select(
    options=nametoparam.keys(),
    value='The three parameters',
    description='Parameters to estimate:',
    disabled=False,
    style=style,
    layout=layout,
)
display(params_name)

def _loss_selector(description):
    """Build a loss-choice Select (options from Loss.list(), default 'MSE-Log')."""
    return widgets.Select(
        options=Loss.list(),
        value='MSE-Log',
        description=description,
        disabled=False,
        style=style, layout=layout)


# One loss selector per estimated quantity, displayed in this order. The six
# values are saved together, in the same order, as the 'loss' list.
lossm0s = _loss_selector('Loss m0s:')
display(lossm0s)

losst1f = _loss_selector('Loss t1f:')
display(losst1f)

losst2f = _loss_selector('Loss t2f:')
display(losst2f)

lossr = _loss_selector('Loss r:')
display(lossr)

losst1s = _loss_selector('Loss t1s:')
display(losst1s)

losst2s = _loss_selector('Loss t2s:')
display(losst2s)

# Proton-density range, saved as 'minPD' and 'maxPD' in the settings file.
minPD = widgets.FloatText(
    description='minimum proton density:',
    value=0.1,
    disabled=False,
    style=style, layout=layout)
display(minPD)

maxPD = widgets.FloatText(
    description='maximum proton density:',
    value=1.,
    disabled=False,
    style=style, layout=layout)
display(maxPD)

# Whether to keep the validation error. Not displayed here: it heads the
# validation form built further down. While it is checked, that form also
# shows the two validation-size widgets defined below.
validation = widgets.Checkbox(
    value=True,
    description='Keep validation error',
    disabled=False, layout=layout)

iscomplex = widgets.Checkbox(
    value=True,
    description='Check if input signals are complex valued.',
    disabled=False, layout=layout)
display(iscomplex)

# Validation sizes. Their labels are longer, hence the wider description.
validation_size = widgets.IntText(
    value=10000,
    description='Validation size:',
    disabled=False,
    style={'description_width': '400px'}, layout=layout)

small_validation_size = widgets.IntText(
    value=10000,
    description='Validation size for a small dataset:',
    disabled=False,
    style={'description_width': '400px'}, layout=layout)

nb_files = widgets.IntText(
    value=0,
    description='Number of files used:',
    disabled=False, layout=layout, style=style
)
display(nb_files)

path_files = widgets.Text(
    value='gaussian',
    description='Folder in "loading_data" containing the files:',
    style=style, layout=layout)
display(path_files)

# Name used for the settings file: settings_<name>.pkl.
save_name = widgets.Text(
    value='1',
    placeholder='Enter the desired name',
    description='Saving name:',
    disabled=False, layout=layout, style=style
)
display(save_name)

button = widgets.Button(description="CLICK HERE to save the settings previously defined for the training !", layout=widgets.Layout(width='50%', height='80px'))
output_params = widgets.Output()


def on_button_clicked(b):
    """Collect every form value and pickle it as the training settings file.

    The settings dict is written to
    ``../settings_files_offline/settings_<save_name>.pkl``, relative to the
    current working directory. The folder is created if needed. Messages are
    printed into the ``output_params`` area displayed under the button.

    Args:
        b: the clicked ``widgets.Button`` (unused).
    """
    import pickle

    with output_params:
        name = save_name.value
        if not name.strip():
            # An empty name would silently produce 'settings_.pkl', which
            # cannot be selected meaningfully with --save_name.
            print('Please enter a saving name before saving the settings.')
            return

        # Key order matches the historical format of the pickled dict.
        dic = {
            'optimizer': optimizer.value,
            'lr': lr.value,
            'model': dropdownmodel.value,
            'noise_type': noise_type.value,
            'noise_level': noise_level.value,
            'normalization': normalization.value,
            'namepca': namepca.value,
            'loss': [lossm0s.value, losst1f.value, losst2f.value,
                     lossr.value, losst1s.value, losst2s.value],
            'batch_size': batch_size.value,
            'start_by_projection': start_by_projection.value,
            'nb_epochs': nb_epochs.value,
            'params': nametoparam[params_name.value],
            'initialization': initialization.value,
            'validation_size': validation_size.value,
            'validation': validation.value,
            'complex': iscomplex.value,
            'small_validation_size': small_validation_size.value,
            'minPD': minPD.value,
            'maxPD': maxPD.value,
            'save_name': name,
            'nb_files': nb_files.value,
            'path_files': path_files.value,
            'dimension_projection': dimension_projection.value,
        }

        folder = '../settings_files_offline'
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs(folder, exist_ok=True)
        # The context manager closes the file even if pickling fails.
        with open(folder + '/settings_' + name + '.pkl', 'wb') as f:
            pickle.dump(dic, f)
        print('Settings have been saved.')


button.on_click(on_button_clicked)


def print_form_valid(**func_kwargs):
    """Callback of the validation ``interactive`` form (global ``i``).

    ``i.children`` follow the keyword order of the ``widgets.interactive``
    calls: the ``validation`` checkbox, then (only while it is checked)
    ``small_validation_size`` and ``validation_size``.

    Args:
        **func_kwargs: widget values supplied by ``interactive``. They are
            unused; the values are read from ``i.children`` instead.

    Returns:
        dict: key 'Valid', plus 'Svs' (small validation size) and 'nfv'
        (validation size) when validation is enabled.
    """
    if i.children[0].value:
        return {'Valid': i.children[0].value,
                'Svs': i.children[1].value,
                'nfv': i.children[2].value}
    else:
        return {'Valid': i.children[0].value}
def for_valid(valid):
    """Observer on ``validation``: show or hide the validation-size widgets.

    Args:
        valid: the change notification; ``valid.new`` is the new checkbox value.

    Builds a new ``interactive`` with or without the size widgets and swaps
    its children into the already-displayed ``i`` widget in place.

    NOTE(review): each toggle creates another ``interactive`` on the same
    shared widgets, and nothing here detaches the previous one. Confirm
    whether the accumulated callbacks matter.
    """
    if valid.new:
        new_i = widgets.interactive(print_form_valid, valid=validation, svs=small_validation_size, nfv=validation_size)
        i.children = new_i.children
    else:
        new_i = widgets.interactive(print_form_valid, valid=validation)
        i.children = new_i.children
# Rebuild the validation form whenever the checkbox is toggled.
validation.observe(for_valid, 'value')
# Output area shown below the form; nothing visible here writes to it.
VALID = widgets.Output()
VALID.clear_output()
# Initial form including the size widgets (validation starts as True).
i = widgets.interactive(print_form_valid, valid=validation, svs=small_validation_size, nfv=validation_size)
input_widgets = widgets.HBox([i], layout=item_layout)
dashboard = widgets.VBox([input_widgets, VALID])
display(dashboard)

# Save button and the area where its messages are printed.
display(button, output_params)

VBox(children=(HBox(children=(interactive(children=(Select(description='Select one architecture.', layout=Layo…

Select(description='Initialization for the first linear layer:', layout=Layout(width='700px'), options=('SGD',…

BoundedFloatText(value=0.001, description='Learning Rate:', layout=Layout(width='700px'), style=DescriptionSty…

Select(description='Noise type:', index=1, layout=Layout(width='700px'), options=('SNR', 'Standard'), style=De…

BoundedFloatText(value=0.03333333333333333, description='Noise Level:', layout=Layout(width='700px'), max=200.…

IntText(value=64, description='Batch Size:', layout=Layout(width='700px'), style=DescriptionStyle(description_…

IntText(value=10000, description='Nb Epochs:', layout=Layout(width='700px'), style=DescriptionStyle(descriptio…

Select(description='Parameters to estimate:', layout=Layout(width='700px'), options=('The three parameters', '…

Select(description='Loss m0s:', index=1, layout=Layout(width='700px'), options=('MSE-CRB', 'MSE-Log', 'MSE', '…

Select(description='Loss t1f:', index=1, layout=Layout(width='700px'), options=('MSE-CRB', 'MSE-Log', 'MSE', '…

Select(description='Loss t2f:', index=1, layout=Layout(width='700px'), options=('MSE-CRB', 'MSE-Log', 'MSE', '…

Select(description='Loss r:', index=1, layout=Layout(width='700px'), options=('MSE-CRB', 'MSE-Log', 'MSE', 'MS…

Select(description='Loss t1s:', index=1, layout=Layout(width='700px'), options=('MSE-CRB', 'MSE-Log', 'MSE', '…

Select(description='Loss t2s:', index=1, layout=Layout(width='700px'), options=('MSE-CRB', 'MSE-Log', 'MSE', '…

FloatText(value=0.1, description='minimun proton density:', layout=Layout(width='700px'), style=DescriptionSty…

FloatText(value=1.0, description='maximum proton density:', layout=Layout(width='700px'), style=DescriptionSty…

Checkbox(value=True, description='Are input signals complex valued?', layout=Layout(width='700px'))

IntText(value=0, description='Number of files used', layout=Layout(width='700px'), style=DescriptionStyle(desc…

Text(value='gaussian', description='Folder in "loading_data" containing the files:', layout=Layout(width='700p…

Text(value='1', description='Saving name:', layout=Layout(width='700px'), placeholder='Enter the name desired'…

VBox(children=(HBox(children=(interactive(children=(Checkbox(value=True, description='Keep validation error', …

Button(description='CLICK HERE to save the settings previously defined for the training !', layout=Layout(heig…

Output()