# Interactive tool for estimation of parameters in quantitative MRI

# Visualization of the results : $\color{red}{\text{OFFLINE FRAMEWORK}}$ 

## Visiting Student : Quentin Duchemin
## Professors : Carlos Fernandez Granda & Jakob Assländer



### 1) Presentation of the project

We estimate biological parameters in quantitative MRI using neural networks. This notebook is an interactive tool for visualizing the results of your training runs. Make sure the networks you select were trained on fingerprints computed $\color{red}{\text{OFFLINE}}$.

### 2) How to use this interface?

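This tool lets you visualize the results of a training run: choose a saved network and a visualization type in the form below, fill in any extra fields that appear for that visualization, then click the button to draw the plot.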

In [1]:
! pip install ipywidgets
! jupyter nbextension enable --py widgetsnbextension
! pip install seaborn
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets
import os
from os import listdir
from os.path import isfile, join
import importlib
from IPython.display import clear_output
import numpy as np
import torch
import seaborn as sns
import matplotlib.pyplot as plt
import scipy as sc
import scipy.io
import pickle
import sys
sys.path.append('..')
from MRF.Training_parameters import *
from MRF.BaseModel import *
from MRF.Projection import *
from MRF.models import *
import MRF
from MRF.Offline import Network, Data_class, Performances
from MRF.Training_parameters import *
clear_output()

# Visualize results

In [5]:
ALL = 'ALL'
def unique_sorted_values_plus_ALL(array):
    """Return ``[ALL]`` followed by the sorted distinct values of ``array``.

    ``array`` must provide ``.unique()`` whose result has ``.tolist()``
    (e.g. a pandas Series); the list is suitable as dropdown options with a
    leading catch-all entry.
    """
    return [ALL] + sorted(array.unique().tolist())
import ipywidgets as widgets
from IPython.display import display
# Offer every regular file in ../save_networks_offline as a selectable network.
# Sorted so the dropdown order (and therefore its default entry) does not depend
# on the arbitrary order returned by listdir.
fnetworks = sorted(f for f in listdir('../save_networks_offline') if isfile(join('../save_networks_offline', f)))
# Keys of a saved network dict that the 'settings' view does not print
# (presumably the weights and per-epoch / per-sample arrays rather than scalar settings).
exclude = ('num_files_validation', 'device', 'NN', 'urls', 'gradients',
           'training_absolute_errors', 'training_relative_errors',
           'training_absolute_errors_over_CRBs',
           'losses_validation', 'params_validation', 't', 'x',
           'validation_absolute_errors', 'validation_relative_errors',
           'validation_absolute_errors_over_CRBs',
           'small_validation_relative_errors', 'dico_validation', 'losses',
           'losses_small_validation', 'CRBs_validation', 'samples')

In [11]:
# Widget styling shared by the dashboard: a flex row for the input bar, and a
# wide description column plus fixed total width for every form control.
item_layout = widgets.Layout(display='flex', justify_content='space-between')
style = dict(description_width='250px')
layout = dict(width='500px')

def moving_average(a, n=3):
    """Return the length-``n`` sliding-window mean of ``a`` (only full windows).

    The result has ``len(a) - n + 1`` entries (none when ``a`` is shorter than
    ``n``); entry ``k`` is the mean of ``a[k:k + n]``. It is computed from a
    float64 running sum, as a difference of two cumulative sums.
    """
    running = np.cumsum(a, dtype=float)
    # running[k - 1] is subtracted from running[n - 1 + k]; the first window has nothing to subtract.
    lagged = np.concatenate(([0.0], running[:-n]))
    return (running[n - 1:] - lagged) / n

def print_inventory(dct):
    """Print a ``SETTINGS:`` header, then one ``key (value)`` line per entry of ``dct``."""
    print("SETTINGS:")
    for setting, val in dct.items():
        print(f"{setting} ({val})")
        
def gram_schmidt(V):
    """Orthonormalize the columns of ``V`` with modified Gram-Schmidt.

    Parameters
    ----------
    V : array_like, shape (n, r)
        Column vectors to orthonormalize, processed left to right.

    Returns
    -------
    ndarray, shape (n, r)
        Orthonormal columns such that, for every k, the first k output columns
        span the same space as the first k input columns. Real input gives a
        float64 result; complex input stays complex. ``r == 0`` gives an empty
        ``(n, 0)`` array.

    Notes
    -----
    Each projection uses the *partially orthogonalized* vector (modified GS)
    rather than the original column (classical GS). In exact arithmetic the two
    agree, but modified GS stays orthogonal to working precision on
    ill-conditioned inputs where classical GS does not. Linearly dependent
    columns still yield NaN/inf columns (their residual norm is zero).
    """
    V = np.asarray(V)
    n, r = V.shape
    Q = np.zeros((n, r), dtype=np.result_type(V.dtype, np.float64))
    for i in range(r):
        v = V[:, i].astype(Q.dtype, copy=True)
        for j in range(i):
            # Project out q_j from the *current* residual (modified Gram-Schmidt).
            v = v - np.vdot(Q[:, j], v) * Q[:, j]
        Q[:, i] = v / np.linalg.norm(v)
    return Q
    
def _use_style(name):
    """Apply matplotlib style ``name``, falling back to its ``seaborn-v0_8-*`` alias.

    Matplotlib 3.6 renamed the bundled ``seaborn-*`` styles to
    ``seaborn-v0_8-*`` and later releases dropped the old names, so the
    original name is tried first and the renamed one if it is rejected.
    """
    try:
        plt.style.use(name)
    except OSError:
        plt.style.use(name.replace('seaborn', 'seaborn-v0_8', 1))


def _load_trained_model(net, name):
    """Rebuild the trained network and its ``Network`` estimator.

    Parameters
    ----------
    net : dict
        Saved training record; must contain ``'name_model'`` (module under
        ``MRF.models``) and ``'NN'`` (the state dict).
    name : str
        File name of the saved network; with ``'network_'`` removed it gives
        ``../settings_files_offline/settings_<rest>.pkl``.

    Returns
    -------
    tuple
        ``(netw, estimation)``: the model with the saved weights in eval mode,
        and the ``Network`` object built from the same settings.
    """
    model = importlib.import_module('MRF.models.' + net['name_model'])
    settings_file = '../settings_files_offline/settings_' + name.replace('network_', '') + '.pkl'
    # NOTE(review): pickle.load can execute arbitrary code; only open trusted settings files.
    with open(settings_file, 'rb') as f:
        settings = pickle.load(f)
    training_parameters = Training_parameters(settings['batch_size'], 1, settings['nb_epochs'],
                                              settings['params'], settings['normalization'])
    projection = Projection(settings['start_by_projection'], settings['dimension_projection'],
                            settings['initialization'], settings['normalization'], settings['namepca'])
    data_class = Data_class(training_parameters, settings['noise_type'], settings['noise_level'],
                            settings['minPD'], settings['maxPD'], settings['nb_files'], settings['path_files'])
    validation_settings = {'validation': settings['validation'],
                           'small_validation_size': settings['small_validation_size'],
                           'validation_size': settings['validation_size']}
    netw = model.model(projection=projection, nb_params=len(settings['params']))
    netw.load_state_dict(net['NN'])
    netw.eval()
    estimation = Network(settings['model'], settings['loss'], training_parameters, settings['save_name'],
                         data_class, validation_settings, projection=projection)
    return netw, estimation


def _plot_rows_in_two_columns(rows):
    """Plot each 1-D entry of ``rows`` in its own panel, two panels per figure row (left to right)."""
    if len(rows) == 0:
        return
    nrows = (len(rows) + 1) // 2
    # squeeze=False keeps ``ax`` 2-D even when there is a single figure row.
    fig, ax = plt.subplots(nrows=nrows, ncols=2, figsize=(10, 35), squeeze=False)
    for idx, row in enumerate(rows):
        panel = ax[idx // 2, idx % 2]
        panel.plot(row)
        panel.set_xlabel('time (in s.)')
    plt.show()


def _plot_first_layer(netw):
    """Plot the weights of every neuron of ``netw.fc1`` and offer a button saving them to .mat files."""
    sns.set()
    layer = np.array(netw.fc1.weight.data)
    _plot_rows_in_two_columns(layer)

    button = widgets.Button(description="CLICK HERE to save those weights !",
                            layout=widgets.Layout(width='50%', height='80px'))
    output = widgets.Output()
    display(button, output)

    def on_button_clicked(b):
        with output:
            # savemat expects a {variable name: array} dict, not a bare array.
            scipy.io.savemat('weights_first_layer.mat', {'weights': layer})
            scipy.io.savemat('weights_first_layer_orthonormal.mat',
                             {'weights': gram_schmidt(layer.T).T})
            print('Saved weights_first_layer.mat and weights_first_layer_orthonormal.mat')
    button.on_click(on_button_clicked)


def _plot_projection_basis(netw, freqcut):
    """Plot the right singular vectors of ``netw.fc1``'s weights after an FFT low-pass filter.

    Frequencies come from ``fftfreq(..., d=1)`` (cycles per sample, ``|f| <= 0.5``),
    so any ``freqcut >= 0.5`` leaves the vectors unchanged.
    """
    from scipy.fftpack import fft, ifft, fftfreq

    layer = np.array(netw.fc1.weight.data)
    # Reduced SVD: exactly one right singular vector per singular value.
    _, _, vt = np.linalg.svd(layer, full_matrices=False)
    freqs = fftfreq(vt.shape[1], d=1)
    filtered = []
    for vec in vt:
        spectrum = fft(vec)
        spectrum[np.abs(freqs) > freqcut] = 0
        # The filtered signal is real up to round-off; plot the real part.
        filtered.append(ifft(spectrum).real)
    _plot_rows_in_two_columns(filtered)


def _plot_nn_vs_nlls_for(k, params, trueparams, MEANnet, STDnet, MEANnlls, STDnlls, figsize):
    """Draw one figure for fingerprint ``k``: per parameter, truth / NN mean±std / NLLS mean±std."""
    _use_style('seaborn-white')
    plt.figure(figsize=figsize)
    bar_kw = dict(fmt='o', ecolor='lightgray', elinewidth=3, capsize=0)
    for j, p in enumerate(params):
        plt.subplot(1, len(params), j + 1)
        plt.ylabel(paramtoname[p])
        plt.errorbar(1, trueparams[k, p], 0.0001, color='black', label='CRB', **bar_kw)
        plt.errorbar(2, MEANnet[k, j], STDnet[k, j], color='red', label='NN', **bar_kw)
        plt.errorbar(3, MEANnlls[k, j], STDnlls[k, j], color='green', label='NLLS', **bar_kw)
        # The three markers sit at x = 1, 2, 3, independent of the number of parameters.
        plt.xticks([1, 2, 3], ('CRB', 'NN', 'NLLS'))
    plt.tight_layout()
    plt.show()


def _plot_nn_vs_nlls(netw, estimation, folder, num):
    """Compare NN and NLLS estimates from ``estimation.local_study`` on ``../noise_files/<folder>``.

    ``num`` selects one fingerprint (plot + printed summary); a value at or
    past the number of fingerprints plots all of them instead.
    """
    MEANnlls, STDnlls, MEANnet, STDnet, CRBs, trueparams = estimation.local_study(
        netw, os.path.join('../noise_files', folder))
    params = estimation.trparas.params
    stats = (trueparams, MEANnet, STDnet, MEANnlls, STDnlls)
    n_fingerprints = CRBs.shape[0]

    # Valid indices are 0 .. n_fingerprints - 1; num == n_fingerprints would be out of range.
    if num >= n_fingerprints:
        print('No fingerprint exists for this number. To see only one fingerprint, choose a number between 0 and '+str(CRBs.shape[0]-1)+'. Here we plot all the fingerprints.')
        for k in range(n_fingerprints):
            print('True parameters', trueparams[k, :])
            _plot_nn_vs_nlls_for(k, params, *stats, figsize=(4 * len(params), 3))
        return

    print('True parameters', trueparams[num, :])
    _plot_nn_vs_nlls_for(num, params, *stats, figsize=(12, 3))
    for j, p in enumerate(params):
        crb = CRBs[num, j]
        print(paramtoname[p]+' :', trueparams[num, p])
        print('                      sqrt{CRB} : ', round(crb, 4))
        print('NEURAL NETWORK mean : ', round(MEANnet[num, j], 4), '    std : ', round(STDnet[num, j], 4),
              '    sqrt{CRB} / std   : ', round(crb / float(STDnet[num, j]), 4))
        print('NON LINEAR LS  mean : ', round(MEANnlls[num, j], 4), '    std : ', round(STDnlls[num, j], 4),
              '    sqrt{CRB} / std   : ', round(crb / float(STDnlls[num, j]), 4))
        print('\n')
        print('\n')


# type_plot -> (suffix of the 'training_'/'validation_' keys in ``net``, title prefix).
_ERROR_CURVES = {
    'relative errors': ('relative_errors', 'Relative error on '),
    'absolute errors': ('absolute_errors', 'absolute error on '),
    'absolute errors over CRBs': ('absolute_errors_over_CRBs', 'absolute error over CRB on '),
}


def _error_series(net, key, column, indice, mvavg):
    """Return the moving average of ``net[key]`` from epoch ``indice`` on, optionally for one column."""
    values = np.array(net[key])[indice:]
    if column is not None:
        values = values[:, column]
    return moving_average(values, n=mvavg)


def _plot_error_curves(net, suffix, title_prefix, mvavg, indice):
    """Plot training (and validation, if run) error curves, one subplot per estimated parameter."""
    params = net['params']
    nparams = len(params)
    plt.figure(figsize=(12, 5 * nparams))
    _use_style('seaborn-whitegrid')
    for i, p in enumerate(params):
        # Single-parameter histories are stored without a parameter axis.
        column = None if nparams == 1 else i
        train = _error_series(net, 'training_' + suffix, column, indice, mvavg)
        absc = np.arange(train.shape[0]) + mvavg + indice
        plt.subplot(nparams, 1, i + 1)
        plt.plot(absc, train, label='training')
        if net['validation']:
            plt.plot(absc, _error_series(net, 'validation_' + suffix, column, indice, mvavg),
                     label='validation')
        plt.legend()
        plt.title(title_prefix + paramtoname[p] + ' along epochs', size=20)
        plt.xlabel('Epochs')
    plt.show()


def _plot_loss(net, mvavg, indice):
    """Plot the moving-averaged training loss, plus the validation loss when validation was run."""
    plt.figure(figsize=(12, 5))
    _use_style('seaborn-whitegrid')
    train_loss = moving_average(net['losses'][indice:], n=mvavg)
    absc = np.arange(train_loss.shape[0]) + mvavg + indice
    plt.subplot()
    plt.plot(absc, train_loss, label='training')
    # NOTE(review): assumes 'losses_validation' is only populated when validation ran.
    if net.get('validation', True):
        plt.plot(absc, moving_average(net['losses_validation'][indice:], n=mvavg), label='validation')
    plt.legend()
    plt.xlabel('Epochs')
    plt.title('Loss along epochs', size=20)
    plt.show()


def _plot_singular_values(netw):
    """Plot the singular values of ``netw.fc1``'s weight matrix."""
    sns.set()
    layer = np.array(netw.fc1.weight.data)
    plt.plot(np.linalg.svd(layer, compute_uv=False))
    plt.xlabel('Number of the singular values')
    plt.ylabel('Singular values of the first linear layer')
    plt.title('Singular values of the projection layer', size=18)
    plt.show()


def _plot_worst_errors(net, netw, estimation, para_abs, para_ordo, para_error):
    """Show where the 100 worst validation relative errors on ``para_error`` lie.

    ``para_abs``/``para_ordo``/``para_error`` are names looked up in
    ``nametoparam`` to get column indices of ``net['params_validation']``.
    """
    paraabs, paraordo, paraerror = (nametoparam[para_abs][0], nametoparam[para_ordo][0],
                                    nametoparam[para_error][0])
    if paraerror not in net['params']:
        print('The parameter chosen for the error was not estimated for this training.')
        return

    VP = net['params_validation']
    # NOTE(review): keeps samples whose column-0 parameter is >= -0.3; reason for the threshold not visible here.
    indices = np.where(VP[:, 0] >= -0.3)[0]
    VP = VP[indices, :]
    with torch.no_grad():
        outputs = netw(net['dico_validation'])[indices, :]
        ind = net['params'].index(paraerror)
        VPerrors = np.abs(estimation.transform(outputs[:, ind], paraerror) - VP[:, paraerror]) / VP[:, paraerror]
    VP, VPerrors = np.array(VP), np.array(VPerrors)
    worst = np.argsort(VPerrors)[-100:]

    sns.set_style("ticks")
    fig, ax = plt.subplots(nrows=3, ncols=1, figsize=(10, 18))
    # edgecolor='none' is matplotlib's "no edge"; '' is not a valid color.
    scatter = ax[0].scatter(VP[worst, paraabs], VP[worst, paraordo], c=VPerrors[worst], s=100, edgecolor='none')
    plt.colorbar(scatter, ax=ax[0])
    ax[0].set_xlabel(para_abs, fontsize=15)
    ax[0].set_ylabel(para_ordo, fontsize=15)
    ax[0].set_title('100 worst relative errors in validation for '+para_error, fontsize=20)

    sns.distplot(VP[worst, paraerror], hist=True, kde=True,
                 bins=int(180/5), color='darkblue',
                 hist_kws={'edgecolor': 'black'}, kde_kws={'linewidth': 4},
                 label='Keeping the 100 samples with highest relative error', ax=ax[1])
    sns.distplot(VP[:, paraerror], hist=True, kde=True,
                 bins=int(180/5), color='darkred',
                 hist_kws={'edgecolor': 'black'}, kde_kws={'linewidth': 4},
                 label='On the whole validation dataset', ax=ax[1])
    ax[1].set_xlabel(para_error, fontsize=15)
    ax[1].set_title('PDF of '+para_error, fontsize=20)
    ax[1].legend(prop={'size': 15})

    # The CRB distribution panel is currently left empty (its plotting code was disabled).
    ax[2].set_xlabel('CRB', fontsize=15)
    ax[2].set_title('PDF of the CRB on '+para_error, fontsize=20)
    fig.subplots_adjust(hspace=0.4)
    plt.tight_layout()
    plt.show()


def common_filtering(net, name, type_plot, mvavg=1, indice=1, folder='Offline', num=0, freqcut=10000, para_abs=2, para_ordo=1, para_error=1):
    """Draw the visualization ``type_plot`` of a saved training into the ``PLOT`` output widget.

    Parameters
    ----------
    net : dict
        Saved training record (loaded with ``torch.load`` by the dashboard).
    name : str
        File name of the saved network, used to locate its settings file.
    type_plot : str
        One of the dashboard's visualization names (e.g. ``'settings'``,
        ``'loss'``, ``'relative errors'``, ``'NN VS NLLS and CRB'``).
    mvavg, indice : int
        Moving-average window and first epoch index for the curve plots.
    folder : str
        Sub-folder of ``../noise_files`` used by ``'NN VS NLLS and CRB'``.
    num : int
        Fingerprint index for ``'NN VS NLLS and CRB'``; a value at or past the
        number of fingerprints plots all of them.
    freqcut : float
        Low-pass cut (cycles per sample) for the projection basis plot.
    para_abs, para_ordo, para_error
        Parameter names (keys of ``nametoparam``) for the ``'error'`` view.
    """
    PLOT.clear_output()
    netw, estimation = _load_trained_model(net, name)
    with PLOT:
        if type_plot == 'settings':
            print_inventory({k: v for k, v in net.items() if k not in exclude})
        elif type_plot == 'first_layer':
            _plot_first_layer(netw)
        elif type_plot == 'Basis functions for the projection subspace':
            _plot_projection_basis(netw, freqcut)
        elif type_plot == 'NN VS NLLS and CRB':
            _plot_nn_vs_nlls(netw, estimation, folder, num)
        elif type_plot in _ERROR_CURVES:
            suffix, title_prefix = _ERROR_CURVES[type_plot]
            _plot_error_curves(net, suffix, title_prefix, mvavg, indice)
        elif type_plot == 'Singular values projection layer':
            _plot_singular_values(netw)
        elif type_plot == 'error':
            _plot_worst_errors(net, netw, estimation, para_abs, para_ordo, para_error)
        elif type_plot == 'gradients wrt loss':
            sns.set_style("ticks")
            fig, ax = plt.subplots()
            plt.plot(net['losses'], net['gradients'])
            ax.set_xlabel('Training loss per epoch')
            ax.set_ylabel('Average of the gradient norm per epoch')
            plt.title('Norm of the gradients of the parameters wrt the training loss', size=15)
            plt.show()
        elif type_plot == 'samples':
            sampled_points = net['samples']
            plt.scatter(np.log10(sampled_points[:, 2]), np.log10(sampled_points[:, 1]))
            plt.xlabel('log(T2)')
            plt.ylabel('log(T1)')
            plt.title('Points sampled during training')
            plt.show()
        elif type_plot == 'loss':
            _plot_loss(net, mvavg, indice)

# Output area into which common_filtering draws every plot.
PLOT = widgets.Output()
PLOT.clear_output()

# Visualizations offered in the dropdown. A list (not a set) so the menu order is
# fixed: str hashing is randomized per process, so a set reorders on each kernel start.
options = ['settings', 'loss', 'relative errors', 'absolute errors', 'absolute errors over CRBs',
           'error', 'NN VS NLLS and CRB', 'gradients wrt loss',
           'first_layer', 'Singular values projection layer', 'Basis functions for the projection subspace']

# Form controls for the dashboard; only a subset is shown depending on the chosen visualization.
dropdown_type = widgets.Dropdown(options = options, value='settings', description='Visualization : ', style=style, layout=layout)
dropdown_network = widgets.Dropdown(options = fnetworks  , description='Chosen network: ', style=style, layout=layout)
dropdown_indice = widgets.BoundedIntText(min=0, max=100000, value=1, step=1, description='Starting epoch: ', style=style, layout=layout)
mv_avg = widgets.BoundedIntText(min=1, max=100000, value=1, step=1, description='Number of epochs for moving average: ', style=style, layout=layout)
# Immediate sub-directories of ../noise_files.
dropdown_folder = widgets.Dropdown(
                    options=next(os.walk('../noise_files'))[1],
                    description='Folder nlls:', style=style, layout=layout)
dropdown_num = widgets.BoundedIntText(
                    value=0,
                    description='Number of the fingerprint to study:',
                    disabled=False,
                    min=0,
                    # BoundedIntText defaults to max=100, which would hide fingerprints past 100 and
                    # make the "plot all" choice (a number past the last fingerprint) unreachable.
                    max=100000,
                    style=style, layout=layout)
# FloatText has no 'min' trait (passing it only triggers a traitlets warning); negative
# values are accepted and zero the whole spectrum in the basis-function plot.
dropdown_freqcut = widgets.FloatText(
                    value=10000,
                    description='Cut frequency:',
                    disabled=False,
                    style=style, layout=layout)
dropdown_para_abs = widgets.Select(
                    options=list(paramtoname.values()),
                    value='m0s',
                    description='Parameter in the x axis:',
                    disabled=False,
                    style=style, layout=layout)

dropdown_para_ordo = widgets.Select(
                    options=list(paramtoname.values()),
                    value='T1',
                    description='Parameter in the y axis:',
                    disabled=False,
                    style=style, layout=layout)

dropdown_para_error = widgets.Select(
                    options=list(paramtoname.values()),
                    value='m0s',
                    description='Parameter for the error:',
                    disabled=False,
                    style=style, layout=layout)

def print_form(**func_kwargs):
    """Summarize the current form values as a dict (displayed by ``widgets.interactive``).

    Uses the keyword values that ``interactive`` passes in (``network``,
    ``plottype`` and the type-specific extras) instead of reading the global
    form ``i``. That global does not exist yet while the first form is
    constructed, and still holds the previous layout while a new form is being
    built in ``for_mv_avg``. Missing extras are reported as ``None``.
    """
    get = func_kwargs.get
    plottype = get('plottype')
    summary = {'Network': get('network'), 'Type': plottype}
    if plottype in ('relative errors', 'loss', 'absolute errors', 'absolute errors over CRBs'):
        summary.update({'Mv Avg': get('mvavg'), 'Indice': get('indice')})
    elif plottype == 'NN VS NLLS and CRB':
        summary.update({'Folder': get('folder'), 'Num': get('num')})
    elif plottype == 'Basis functions for the projection subspace':
        summary['Freqcut'] = get('freqcut')
    elif plottype == 'error':
        summary.update({'paraabs': get('para_abs'),
                        'paraordo': get('para_ordo'),
                        'paraerror': get('para_error')})
    return summary
        

def for_mv_avg(plottype):
    """Rebuild the input form whenever the visualization type changes.

    ``plottype`` is the change notification from ``dropdown_type.observe``;
    ``plottype.new`` is the newly selected type. The network and type
    dropdowns are always shown; the extra inputs depend on the type. The
    freshly built form's children replace those of the displayed form ``i``.
    """
    selected = plottype.new
    if selected in ['relative errors', 'loss', 'absolute errors', 'absolute errors over CRBs']:
        extra_inputs = {'mvavg': mv_avg, 'indice': dropdown_indice}
    elif selected == 'NN VS NLLS and CRB':
        extra_inputs = {'folder': dropdown_folder, 'num': dropdown_num}
    elif selected == 'Basis functions for the projection subspace':
        extra_inputs = {'freqcut': dropdown_freqcut}
    elif selected == 'error':
        extra_inputs = {'para_abs': dropdown_para_abs,
                        'para_ordo': dropdown_para_ordo,
                        'para_error': dropdown_para_error}
    else:
        extra_inputs = {}
    rebuilt = widgets.interactive(print_form, network=dropdown_network, plottype=dropdown_type, **extra_inputs)
    i.children = rebuilt.children


dropdown_type.observe(for_mv_avg, 'value')

button = widgets.Button(description="CLICK HERE to change the network or plot !",
                        layout=widgets.Layout(width='30%', height='80px'))


def on_button_clicked(b):
    """Load the selected saved network and draw the selected visualization into PLOT."""
    chosen_file = dropdown_network.value
    # NOTE(review): torch.load unpickles arbitrary objects; only open trusted network files.
    net = torch.load(join('../save_networks_offline', chosen_file), map_location='cpu')
    common_filtering(
        net, chosen_file, dropdown_type.value,
        mvavg=mv_avg.value,
        indice=dropdown_indice.value,
        folder=dropdown_folder.value,
        num=dropdown_num.value,
        freqcut=dropdown_freqcut.value,
        para_abs=dropdown_para_abs.value,
        para_ordo=dropdown_para_ordo.value,
        para_error=dropdown_para_error.value,
    )


# Initial form: only the network and visualization dropdowns (for_mv_avg adds the rest).
i = widgets.interactive(print_form, network=dropdown_network, plottype=dropdown_type)
button.on_click(on_button_clicked)
input_widgets = widgets.HBox([i, button], layout=item_layout)
# NOTE(review): `tab` is built but never displayed.
tab = widgets.Tab([PLOT], layout=item_layout)
dashboard = widgets.VBox([input_widgets, PLOT])
display(dashboard)

VBox(children=(HBox(children=(interactive(children=(Dropdown(description='Chosen network: ', layout=Layout(wid…

In [None]:
# Scratch cell: rebuild the 'full_joint_deep3' model from settings_32, load the saved
# weights into it and print how many parameter tensors it has.
# `net` (the saved network dict) is not loaded here; it must come from another cell.
if 'net' not in globals():
    raise NameError("`net` is undefined: load the saved network matching settings_32 "
                    "with torch.load(...) before running this cell.")
model = importlib.import_module('MRF.models.full_joint_deep3')

with open('../settings_files_offline/settings_32.pkl', 'rb') as f:
    settings = pickle.load(f)
training_parameters = Training_parameters(settings['batch_size'], 1, settings['nb_epochs'], settings['params'], settings['normalization'])
projection = Projection(settings['start_by_projection'], settings['dimension_projection'], settings['initialization'], settings['normalization'])
# NOTE(review): passes settings['url_file'] where common_filtering passes nb_files and path_files;
# confirm which Data_class signature this (older?) settings file corresponds to.
data_class = Data_class(training_parameters, settings['noise_type'], settings['noise_level'],
                        settings['minPD'], settings['maxPD'], settings['url_file'])
validation_settings = {'validation': settings['validation'], 'small_validation_size': settings['small_validation_size'], 'validation_size': settings['validation_size']}
netw = model.model(projection=projection, nb_params=len(settings['params']))
device = torch.device('cpu')
netw.load_state_dict(net['NN'])
print(len(list(netw.parameters())))

In [None]:
# Scratch cell: reload the 'network_crb-gauss-ligo' training and rebuild its model in eval mode.
net = torch.load(join('../save_networks_offline','network_crb-gauss-ligo'),map_location='cpu')
model = importlib.import_module('MRF.models.'+net['name_model'])
with open('../settings_files_offline/settings_crb-gauss-ligo.pkl', 'rb') as f:
    settings = pickle.load(f)
training_parameters = Training_parameters(settings['batch_size'], 1, settings['nb_epochs'], settings['params'], settings['normalization'])
projection = Projection(settings['start_by_projection'], settings['dimension_projection'], settings['initialization'], settings['normalization'])
# Same Data_class arguments (nb_files, path_files) as common_filtering and as the later cell
# that reads this very settings file; settings['url_file'] did not match either of them.
data_class = Data_class(training_parameters, settings['noise_type'], settings['noise_level'],
                        settings['minPD'], settings['maxPD'], settings['nb_files'], settings['path_files'])
validation_settings = {'validation': settings['validation'],'small_validation_size': settings['small_validation_size'], 'validation_size': settings['validation_size']}
netw = model.model(projection=projection,nb_params=len(settings['params']))
device = torch.device('cpu')
netw.load_state_dict(net['NN'])
netw.eval()
    
    


In [None]:
# Scratch: order validation samples by the CRB stored in column 0, print the sorted CRBs,
# and histogram column 0 of the parameters of the 50 samples with the smallest CRB.
arg = np.array(net['CRBs_validation'][:, 0]).argsort()
print(net['CRBs_validation'][arg, 0])
plt.hist(net['params_validation'][arg[:50], 0])



In [None]:
# Scratch experiment: rebuild the crb-gauss-ligo model, simulate two MT fingerprints whose
# first parameter differs (0.3 vs 0.34), plot their first column and print their cosine similarity.
from MRF.simulate_signal import simulation
model = importlib.import_module('MRF.models.ligo')
net = torch.load(join('../save_networks_offline', 'network_crb-gauss-ligo'), map_location='cpu')

with open('../settings_files_offline/settings_crb-gauss-ligo.pkl', 'rb') as f:
    settings = pickle.load(f)

training_parameters = Training_parameters(settings['batch_size'], 1, settings['nb_epochs'],
                                          settings['params'], settings['normalization'])
projection = Projection(settings['start_by_projection'], settings['dimension_projection'],
                        settings['initialization'], settings['normalization'])
data_class = Data_class(training_parameters, settings['noise_type'], settings['noise_level'],
                        settings['minPD'], settings['maxPD'], settings['nb_files'], settings['path_files'])
validation_settings = {'validation': settings['validation'],
                       'small_validation_size': settings['small_validation_size'],
                       'validation_size': settings['validation_size']}
netw = model.model(projection=projection, nb_params=len(settings['params']))
device = torch.device('cpu')
netw.load_state_dict(net['NN'])
netw.eval()
estimation = Network(settings['model'], settings['loss'], training_parameters, settings['save_name'],
                     data_class, validation_settings, projection=projection)

# NOTE(review): prms[1] is passed twice (5th and 8th positional arguments); confirm against
# simulate_MT_ODE's signature.
prms = [0.3, 1, 0.1, 100, 1e-2]
y, dy = simulation.simulate_MT_ODE(data_class.x, data_class.TR, data_class.t, *prms[:4], prms[1], prms[4])

prms = [0.34, 1, 0.1, 100, 1e-2]
y2, dy = simulation.simulate_MT_ODE(data_class.x, data_class.TR, data_class.t, *prms[:4], prms[1], prms[4])

plt.plot(y[:, 0])
plt.plot(y2[:, 0])

first, second = y[:, 0], y2[:, 0]
print(np.dot(first, second) / np.linalg.norm(second) / np.linalg.norm(first))




In [None]:
# Scratch check of the row-wise (dim=1) Euclidean norm of a 5x3 tensor of ones.
a = torch.ones(5, 3)
print(a.norm(dim=1))

In [None]:
# The former guard `if 'Without':` tested a non-empty string literal, which is always
# truthy, so the print always ran; the dead condition is removed.
print('OK')