"""
(c) Stefano B. Blumberg and Paddy J. Slator, do not redistribute or modify

Code to replicate the ADC experiment (run alongside the MATLAB code that computes the CRLB-optimised protocol; TODO: consider translating that code to Python). <Add paper link>

Overview of cells (numbers refer to the "########## (N)" cell markers):
    - Choose data split sizes: 2
    - Generate data examples: 3-A/B/C
    - Data format for JOFSTO: 4
    - Option to pass data directly, or save to disk and load: 5-A/B
    - JOFSTO hyperparameters: 6, 7, 8
    - Data normalization: 9
"""


In [None]:
########## (1)
# Import modules, see requirements.txt for jofsto requirements, set global seed

import numpy as np
from jofsto_code.jofsto_main import return_argparser, run

import matplotlib.pyplot as plt

# Seed NumPy's global generator once, so every np.random draw below is reproducible
np.random.seed(seed=0)

In [None]:
# Directories and filenames to save data

In [None]:
########## (2)
# Data split sizes: the validation and test sets are each one tenth of the training set

n_train = 1000  # No. training voxels, reduce for faster training speed
n_val = n_test = n_train // 10  # No. validation / test set voxels

In [None]:
########## (3-A)
# Create dummy, randomly generated (positive) data

# C_bar = 220
# M = 12  # Number of input measurements \bar{C}, Target regressors
# rand = np.random.lognormal  # Randomly generates positive values
# train_inp, train_tar = rand(size=(n_train, C_bar)), rand(size=(n_train, M))
# val_inp, val_tar = rand(size=(n_val, C_bar)), rand(size=(n_val, M))
# test_inp, test_tar = rand(size=(n_test, C_bar)), rand(size=(n_test, M))


# #########
# # Generate data using an ADC model
# Super-design: nb b-values evenly spaced (endpoints included) from minb to maxb
minb, maxb = 0, 5000
nb = 192

# \bar{C}: number of candidate measurements JOFSTO selects from
C_bar = nb

bvals = np.linspace(start=minb, stop=maxb, num=nb)

def adc(D,bvals):
    """Mono-exponential ADC signal model, S(b) = exp(-b * D).

    Parameters
    ----------
    D : float or array
        Apparent diffusion coefficient(s). Any shape that broadcasts against
        ``bvals`` works (e.g. a column of values against a row of b-values).
        Its units must be the reciprocal of the b-value units, so b * D is
        dimensionless.
    bvals : array
        b-values at which to evaluate the signal.

    Returns
    -------
    array
        Normalised signal, equal to 1 at b = 0.
    """
    return np.exp(-bvals * D)

n_samples = n_train + n_val + n_test

# Ground-truth ADC values, drawn uniformly from [minD, maxD)
minD, maxD = 0.0001, 0.003
parameters = np.random.uniform(low=minD, high=maxD, size=n_samples)

# Noise-free signals, one row per sample: row i equals adc(parameters[i], bvals).
# Evaluated for all samples at once by broadcasting a column of D values against
# the row of b-values, then stored as float32.
signals = adc(parameters[:, np.newaxis], bvals).astype(np.float32)



In [None]:
#add noise
def add_noise(data, scale=0.02):
    """Return the magnitude of ``data`` after adding complex Gaussian noise (Rician noise).

    Independent zero-mean Gaussian noise with standard deviation ``scale`` is
    added to a real channel (which carries ``data``) and to an empty imaginary
    channel, and the modulus is returned.

    The real-channel noise is drawn before the imaginary-channel noise from
    NumPy's global generator. Keeping this order keeps results reproducible
    under ``np.random.seed``.
    """
    shape = np.shape(data)
    noise_re = np.random.normal(scale=scale, size=shape)
    noise_im = np.random.normal(scale=scale, size=shape)
    return np.sqrt((data + noise_re) ** 2 + noise_im ** 2)

# Corrupt the clean simulated signals with Rician noise at the default level
signals = add_noise(data=signals, scale=0.02)



In [None]:
# Split signals/parameters into train/val/test. No shuffling is needed because
# the ground-truth parameters were already drawn independently at random.
train_end = n_train
val_end = n_train + n_val
test_end = n_train + n_val + n_test

train_signals = signals[:train_end, :]
val_signals = signals[train_end:val_end, :]
test_signals = signals[val_end:test_end, :]

# Targets are the ground-truth ADC values, not the signals. (Previously these
# were sliced from `signals`, so JOFSTO was asked to regress the input itself.)
# A trailing axis gives each split shape (n, M) with M = 1 target regressor,
# matching the "n x M" target layout of the JOFSTO data dict.
train_parameters = parameters[:train_end, np.newaxis]
val_parameters = parameters[train_end:val_end, np.newaxis]
test_parameters = parameters[val_end:test_end, np.newaxis]

train_inp = train_signals
train_tar = train_parameters
val_inp = val_signals
val_tar = val_parameters
test_inp = test_signals
test_tar = test_parameters


In [None]:
########## (4)
# Load data into JOFSTO format

# Data in JOFSTO format, \bar{C} measurements, M target regressors
data = dict(
    train=train_inp,  # Shape n_train x \bar{C}
    train_tar=train_tar,  # Shape n_train x M
    val=val_inp,  # Shape n_val x \bar{C}
    val_tar=val_tar,  # Shape n_val x M
    test=test_inp,  # Shape n_test x \bar{C}
    test_tar=test_tar,  # Shape n_test x M
)

# `yaml` was used below without ever being imported; import it here
# (PyYAML is already a dependency of this notebook via base.yaml).
import yaml

# __file__ is undefined inside a notebook, so the base config path is set explicitly
# (previously: os.path.dirname(__file__) + "/base.yaml").
base_yaml_fil = "/home/blumberg/Bureau/z_Automated_Measurement/Code/base.yaml"
with open(base_yaml_fil, "r", encoding="utf-8") as f:
    jofsto_args = yaml.safe_load(f)

# Later cells call jofsto_args.extend(...) and parser.parse_args(jofsto_args), so the
# base config must be a YAML list of command-line tokens. An empty file loads as None.
if jofsto_args is None:
    jofsto_args = []
if not isinstance(jofsto_args, list):
    raise TypeError(
        f"{base_yaml_fil} must contain a YAML list of command-line arguments, "
        f"got {type(jofsto_args).__name__}"
    )
# argparse expects string tokens; YAML may have parsed numbers as int/float
jofsto_args = [str(tok) for tok in jofsto_args]

# Argument parser used by every `parser.parse_args(jofsto_args)` call in later cells.
# It was referenced but never created; `return_argparser` is imported at the top.
# NOTE(review): assumes return_argparser() takes no arguments — confirm.
parser = return_argparser()

In [None]:
########## (5-A)
# Option to save data to disk, and JOFSTO load

# Add path to save file. The dict is written as a pickled 0-d object array.
data_fil = "/Users/paddyslator/python/ED_MRI/adc_simulations.npy"
np.save(data_fil, data)
print(f"Saving data as {data_fil}")
pass_data = None
jofsto_args.extend(["--data_fil", data_fil])


########## (5-B)
# Option to pass data to JOFSTO directly
# NOTE: both options currently run. This overrides pass_data = None from 5-A,
# while the file written above is still passed via --data_fil.

pass_data = data

In [None]:
########## (6)
# Simplest version of JOFSTO, modifying the most important hyperparameters


# Decreasing feature subset sizes for JOFSTO to consider: C_bar, C_bar/2, ..., C_bar/16
# (previously assigned twice, identically)
C_i_values = [C_bar, C_bar // 2, C_bar // 4, C_bar // 8, C_bar // 16]
jofsto_args.extend(["--C_i_values"] + [str(val) for val in C_i_values])

# Feature subset sizes for JOFSTO evaluated on test data: every size except the full set
C_i_eval = C_i_values[1:]
jofsto_args.extend(["--C_i_eval"] + [str(val) for val in C_i_eval])

# Scoring net C_bar -> num_units_score[0] -> num_units_score[1] ... -> C_bar units
num_units_score = [1000, 1000]
jofsto_args.extend(["--num_units_score"] + [str(val) for val in num_units_score])

# Task net C_bar -> num_units_task[0] -> num_units_task[1] ... -> M units
num_units_task = [1000, 1000]
jofsto_args.extend(["--num_units_task"] + [str(val) for val in num_units_task])

# Output location and run naming. Results are read back later in this notebook
# from <out_base>/<proj_name>/results/.
jofsto_args.extend(["--out_base", "/Users/paddyslator/python/ED_MRI/test1"])
jofsto_args.extend(["--proj_name", "adc"])
jofsto_args.extend(["--run_name", "test"])


args = parser.parse_args(jofsto_args)
run(args=args, pass_data=pass_data)


In [None]:
#load the JOFSTO output
# allow_pickle is required because the file stores a Python object (unwrapped with
# .item()). Only load trusted files this way.
JOFSTO_output = np.load(
    "/Users/paddyslator/python/ED_MRI/test1/adc/results/test_all.npy", allow_pickle=True
).item()
# Other cells refer to the results under the misspelled name `JOFTSO_output`,
# so keep that name as an alias.
JOFTSO_output = JOFSTO_output

In [None]:
# Load the CRLB-optimised protocol (computed in MATLAB) and rescale its b-values
import scipy.io as sio

CRLB_ADC = sio.loadmat("/Users/paddyslator/MATLAB/adc_crlb/crlb_adc_optimised_protocol.mat")
# 'b_opt' is stored as a MATLAB matrix: flatten it to 1-D, then scale by 1e3.
# NOTE(review): presumably the scaling converts to the same units as `bvals` — confirm.
bvals_CRLB_ADC = 1e3 * np.squeeze(CRLB_ADC["b_opt"])

In [None]:
# Plot the b-values of each scheme on the noise-free ADC curve for D = 0.001
# b-values JOFSTO selected at the smallest subset size
bvals_jofsto = bvals[JOFTSO_output[C_i_values[-1]]['measurements']]

plt.plot(bvals, adc(0.001, bvals), 'o')  # all super-design b-values
plt.plot(bvals_jofsto, adc(0.001, bvals_jofsto), 'x')  # JOFSTO choice
plt.plot(bvals_CRLB_ADC, adc(0.001, bvals_CRLB_ADC), 'o')  # CRLB choice


In [None]:
# Shape of the held-out ground-truth targets
np.shape(test_parameters)

In [None]:
# Correlation between the ground-truth ADC (first target column) and JOFSTO's test-set
# prediction at the smallest subset size (C_i_values[-1], which is 12 for C_bar = 192).
# Uses `JOFTSO_output`, the name the results are loaded into; the `JOFSTO_output`
# spelling was never defined by the original load cell.
np.corrcoef(test_tar[:, 0], JOFTSO_output[C_i_values[-1]]["test_output"][:, 0])

In [None]:
from scipy.optimize import minimize

# Fit the ADC model on the full (super-design), JOFSTO and CRLB acquisitions.
# First simulate noisy data for each acquisition.

# b-values JOFSTO selected at the smallest subset size
bvals_jofsto = bvals[JOFTSO_output[C_i_values[-1]]['measurements']]

signals_crlb = np.zeros((n_samples, len(bvals_CRLB_ADC)))
signals_super = np.zeros((n_samples, len(bvals)))
signals_jofsto = np.zeros((n_samples, len(bvals_jofsto)))

# Fresh ground-truth ADC values, independent of the ones used for training
parameters = np.random.uniform(low=minD, high=maxD, size=n_samples)

# Per sample, noise is drawn for CRLB, then super-design, then JOFSTO, so the seeded
# random stream is consumed in a fixed order.
for i, D_true in enumerate(parameters):
    signals_crlb[i, :] = add_noise(adc(D_true, bvals_CRLB_ADC), scale=0.1)
    signals_super[i, :] = add_noise(adc(D_true, bvals), scale=0.1)
    signals_jofsto[i, :] = add_noise(adc(D_true, bvals_jofsto), scale=0.1)

def objective_function(D,bvals,signals):
    """Mean squared error between measured ``signals`` and the ADC model at ``bvals``.

    Used as the least-squares objective passed to ``scipy.optimize.minimize``.
    """
    residuals = signals - adc(D, bvals)
    return np.mean(residuals ** 2)
    
    
# Least-squares fit of the ADC for every sample and acquisition, each fit starting
# from the same initial guess.
Dstart = 0.001

# JOFSTO-selected b-values, looked up once rather than once per sample
bvals_jofsto = bvals[JOFTSO_output[C_i_values[-1]]['measurements']]

fitted_parameters_crlb = np.zeros(n_samples)
fitted_parameters_super = np.zeros(n_samples)
fitted_parameters_jofsto = np.zeros(n_samples)

for i in range(n_samples):
    # OptimizeResult.x is a length-1 array, so take element 0 explicitly. Assigning
    # a size-1, ndim>0 array to a scalar slot is deprecated since NumPy 1.25.
    fitted_parameters_crlb[i] = minimize(
        objective_function, Dstart,
        args=(bvals_CRLB_ADC, signals_crlb[i, :]), method='Nelder-Mead',
    ).x[0]
    fitted_parameters_super[i] = minimize(
        objective_function, Dstart,
        args=(bvals, signals_super[i, :]), method='Nelder-Mead',
    ).x[0]
    fitted_parameters_jofsto[i] = minimize(
        objective_function, Dstart,
        args=(bvals_jofsto, signals_jofsto[i, :]), method='Nelder-Mead',
    ).x[0]





In [None]:
# Fitted vs ground-truth ADC for each acquisition scheme
plt.plot(parameters, fitted_parameters_crlb, 'o', markersize=1)
plt.plot(parameters, fitted_parameters_super, 'v', markersize=1)
plt.plot(parameters, fitted_parameters_jofsto, 'x', markersize=1)

# Pearson correlation between ground truth and fit, one line per scheme
for label, fitted in (
    ("CRLB correlation: ", fitted_parameters_crlb),
    ("super correlation ", fitted_parameters_super),
    ("JOFSTO correlation ", fitted_parameters_jofsto),
):
    print(label + str(np.corrcoef(parameters, fitted)[0, 1]))




In [None]:
# Off-diagonal entry of the 2x2 correlation matrix: Pearson r for the CRLB fits
np.corrcoef(parameters, fitted_parameters_crlb)[0][1]

In [None]:
# Debug a single CRLB fit: compare noisy data, ground-truth curve and fitted curve
i = 10

plt.plot(bvals_CRLB_ADC, signals_crlb[i, :])
plt.plot(bvals_CRLB_ADC, adc(parameters[i], bvals_CRLB_ADC), 'o')
plt.plot(bvals_CRLB_ADC, adc(fitted_parameters_crlb[i], bvals_CRLB_ADC), 'x')


def debug_objective_function(D, bvals, signals):
    """Verbose copy of the fitting objective: print data and model, return their MSE.

    This has its own name so that running this cell does not rebind the global
    ``objective_function`` used by the fitting cell. If that name were rebound,
    re-running the fits would print on every optimiser iteration.
    """
    print(signals)
    print(adc(D, bvals))
    return np.mean((signals - adc(D, bvals)) ** 2)


print('ground truth: ' + str(parameters[i]))
print('fitted: ' + str(fitted_parameters_crlb[i]))

debug_objective_function(fitted_parameters_crlb[i], bvals_CRLB_ADC, signals_crlb[i, :])

In [None]:
# b-values JOFSTO selected at the smallest subset size
bvals[JOFTSO_output[C_i_values[-1]]["measurements"]]

In [None]:
########## (7)
# Modify more JOFSTO hyperparameters, less important, may change results

# E_1 in paper: fix the score after this epoch
epochs_fix_sigma = 25
# E_2 - E_1 in paper: no. epochs over which the score is progressively made sample independent
epochs_decay_sigma = 10
# E_3 - E_2 in paper: no. epochs over which the mask is progressively modified
epochs_decay = 10

jofsto_args.extend([
    "--epochs_fix_sigma", str(epochs_fix_sigma),
    "--epochs_decay_sigma", str(epochs_decay_sigma),
    "--epochs_decay", str(epochs_decay),
])

args = parser.parse_args(jofsto_args)
run(args=args, pass_data=pass_data)

In [None]:

########## (8)
# Deep learning training hyperparameters for inner loop

# Training epochs per step; set large so that early stopping ends training
total_epochs = 10000
# Training learning rate
learning_rate = 0.0001
# Training batch size
batch_size = 1500

jofsto_args.extend([
    "--total_epochs", str(total_epochs),
    "--learning_rate", str(learning_rate),
    "--batch_size", str(batch_size),
])

args = parser.parse_args(jofsto_args)
run(args=args, pass_data=pass_data)


In [None]:
# NOTE(review): ".npy" looks like a placeholder path. It loads a file literally named
# ".npy" from the working directory; fill in the intended file.
# allow_pickle lets the file hold a Python object; only load trusted files this way.
thing = np.load(file=".npy", allow_pickle=True)

In [None]:
# Unwrap the 0-d object array loaded above and list its keys
# (presumably a dict saved with np.save — confirm against the file being loaded).
thing.item().keys()