In [None]:
import optuna
import matplotlib.pyplot as plt
import os
import cedne
from cedne import utils
from cedne import simulator
import numpy as np
import json
import scipy.stats as ss
import psycopg2

In [None]:
# Plot the loss trajectory of every study stored in the shared PostgreSQL Optuna database.
# root_path = '/Users/sahilmoza/Documents/Codes/CEDNe/examples/python/cluster_plots/5406695'
# The connection string can be overridden via the CEDNE_OPTUNA_STORAGE environment variable
# so the notebook is not tied to one user's socket/port.
storage = os.environ.get(
    "CEDNE_OPTUNA_STORAGE",
    "postgresql://smoza@/cedne_optimization_optuna?host=/tmp&port=5433",
)
# One storage object (one DB engine) is reused for listing and for loading every study.
rdb_storage = optuna.storages.RDBStorage(storage)
frozen_studies = rdb_storage.get_all_studies()
for frozen_study in frozen_studies:
    # NOTE: `study` / `frozen_study` stay bound to the last study after the loop;
    # later cells read `frozen_study.study_name` and `study.best_params`.
    study = optuna.load_study(study_name=frozen_study.study_name, storage=rdb_storage)
    trials = study.trials
    completed = [t for t in trials if t.state == optuna.trial.TrialState.COMPLETE]
    trial_numbers = [t.number for t in completed]
    loss_values = [t.value for t in completed]

    # Plot loss over trials; the title identifies the study (names are long paths, keep the basename).
    f, ax = plt.subplots(figsize=(5,3))
    ax.plot(trial_numbers, loss_values, marker="o", linestyle="-", color="gray")
    ax.set_yscale("log")
    ax.set_xlabel("Trial Number")
    ax.set_ylabel("Log Loss")
    ax.set_title(f"{os.path.basename(frozen_study.study_name)} Loss over Trials")
    utils.simpleaxis(ax)
    plt.show()

In [None]:
# Name of the last study visited by the PostgreSQL loss-plot loop (its loop variable stays bound).
frozen_study.study_name

In [None]:
# Load the Atanas 2023 (Flavell lab NeuroPAL) control recordings from the local copy.
# Keys reproduce the study names used on the cluster:
#   '<cluster data dir>/<file name>_50_25_64'
# so an Optuna study name can be used directly to look up its recording.
DATA_DIR = '/Users/sahilmoza/Documents/Postdoc/Yun Zhang/data/SteveFlavell-NeuroPAL-Cell/Control/'
CLUSTER_DATA_DIR = '/n/home05/smoza/CEDNe/data_sources/downloads/Atanas_2023/Control/'
jsons = {}
for js in sorted(os.listdir(DATA_DIR)):
    if js.startswith('.'):
        # Skip hidden files (e.g. macOS '.DS_Store'); they are not JSON and would crash json.load.
        continue
    with open(os.path.join(DATA_DIR, js), 'r') as fh:
        jsons[CLUSTER_DATA_DIR + js + '_50_25_64'] = json.load(fh)

In [None]:
# Display the collected best parameters.
# NOTE(review): `best_params` is first assigned in later cells; on a fresh kernel this raises NameError.
best_params

In [None]:
# Re-simulate the last PostgreSQL study (`study` from the loss-plot loop) with its best parameters
# and overlay the simulated activity (orange) on the recorded traces (gray).
# NOTE(review): this cell uses names defined further down the notebook (`nn_chem`,
# `measuredNeurons`, `parse_parameters`); it only works after those cells have run.
# `measuredNeurons` must also be keyed like `study.study_name` — confirm before trusting the lookup.
best_params = {study.study_name: study.best_params}
neurons = {}
for study_name, study_params in best_params.items():
    # Every neuron appearing in a fitted weight 'weight:<pre>:<post>:<suffix>'.
    weight_neurons = set()
    for key in study_params:
        if key.startswith('weight'):
            _, n1, n2, _ = key.split(':')
            weight_neurons.update((n1, n2))
    # Sorted so the subnetwork is built in the same order on every run.
    neurons[study_name] = sorted(weight_neurons)

database = study.study_name
nn_chem_sub = nn_chem.subnetwork(neuron_names=neurons[database])
traces = jsons[database]['trace_array']
trace_index = measuredNeurons[database]
for neuron in nn_chem_sub.neurons:
    if neuron in neurons[database]:
        nn_chem_sub.neurons[neuron].set_property('amplitude', traces[trace_index[neuron]])

# Assumes every trace in a recording has the same length; take it from the first measured neuron.
num_timepoints = len(traces[trace_index[next(iter(trace_index))]])
input_nodes = [nn_chem_sub.neurons[n] for n in nn_chem_sub.neurons if nn_chem_sub.neurons[n].type == 'sensory']
inputs = []
time_points = np.arange(num_timepoints)
for inp in input_nodes:
    if hasattr(inp, 'amplitude'):
        input_value = {t: inp.amplitude[j] for j, t in enumerate(time_points)}
        inputs.append(simulator.TimeSeriesInput([inp], input_value))

params = parse_parameters(best_params[database])

# Defaults for anything the optimiser did not fit; fitted values override them below.
baseline = {nn_chem_sub.neurons[n]: 0 for n in nn_chem_sub.neurons}
gains = {nn_chem_sub.neurons[n]: 1 for n in nn_chem_sub.neurons}
time_constants = {nn_chem_sub.neurons[n]: 1 for n in nn_chem_sub.neurons}
weights = {(nn_chem_sub.neurons[e[0].name], nn_chem_sub.neurons[e[1].name]): 1 for e in nn_chem_sub.edges}

# .get(kind, {}): a study with no fitted parameter of some kind keeps the defaults instead of
# raising KeyError.
baseline.update({nn_chem_sub.neurons[n]: v for n, v in params.get('baseline', {}).items()})
gains.update({nn_chem_sub.neurons[n]: v for n, v in params.get('gain', {}).items()})
time_constants.update({nn_chem_sub.neurons[n]: v for n, v in params.get('time_constant', {}).items()})
weights.update({(nn_chem_sub.neurons[pre], nn_chem_sub.neurons[post]): v
                for (pre, post), v in params.get('weight', {}).items()})

rate_model = simulator.RateModel(nn_chem_sub, input_nodes, weights, gains, time_constants, baseline,
                                 static_neurons=input_nodes, time_points=time_points, inputs=inputs)
rate_model.time_points = time_points
res = rate_model.simulate()

# squeeze=False keeps the Axes array 2-D even when only one neuron is simulated
# (a lone Axes is not subscriptable); take the single column as a 1-D array.
f, ax = plt.subplots(figsize=(10, 2*len(res)), nrows=len(res), sharex=True, squeeze=False, layout='constrained')
ax = ax[:, 0]
for j, k in enumerate(res.keys()):
    node = nn_chem_sub.neurons[str(k.name)]
    utils.simpleaxis(ax[j])
    if hasattr(node, 'amplitude'):
        recorded = np.array(node.amplitude)
        ax[j].plot(time_points, recorded, color='gray')
        # Panel title: Pearson correlation between recorded and simulated activity.
        ax[j].set_title(f'{np.corrcoef(recorded[time_points], res[k])[0,1]}')
    ax[j].plot(time_points, res[k], color='orange', label=f'{k.name}-{node.name}')
    ax[j].legend(frameon=False)
f.suptitle(f'{database}')
plt.show()

In [None]:
# Plot loss-vs-trial for every per-recording SQLite study under `root_path`.
# root_path is defined here; it used to be set only in a later cell, so this cell failed on a fresh kernel.
root_path = '/Users/sahilmoza/Documents/Codes/CEDNe/examples/python/data_sources/downloads/Atanas_2023/Control'
for study_name in sorted(os.listdir(root_path)):
    if not os.path.isdir(os.path.join(root_path, study_name)):
        # Each study lives in its own sub-directory; skip stray files such as '.DS_Store'.
        continue

    DB_PATH = f'sqlite:///{root_path}/{study_name}/cedne_optimization_optuna.db'
    study_names = optuna.study.get_all_study_names(storage=DB_PATH)
    print("Available studies:", study_names)
    print(study_name, DB_PATH)
    # The study inside each DB is named by a relative path to its recording.
    study = optuna.load_study(study_name=f'../../data_sources/downloads/Atanas_2023/Control/{study_name}', storage=DB_PATH)
    # Extract trial data
    trials = study.trials
    completed = [t for t in trials if t.state == optuna.trial.TrialState.COMPLETE]
    trial_numbers = [t.number for t in completed]
    loss_values = [t.value for t in completed]

    # Plot loss over trials
    f, ax = plt.subplots(figsize=(5,3))
    ax.plot(trial_numbers, loss_values, marker="o", linestyle="-", color="gray")
    ax.set_yscale("log")
    ax.set_xlabel("Trial Number")
    ax.set_ylabel("Log Loss")
    ax.set_title(f"{study_name} Loss over Trials")
    utils.simpleaxis(ax)
    plt.show()


In [None]:
# Collect each per-recording SQLite study's best parameters, keyed by its sub-directory name.
best_params = {}
root_path = '/Users/sahilmoza/Documents/Codes/CEDNe/examples/python/data_sources/downloads/Atanas_2023/Control'
for study_name in sorted(os.listdir(root_path)):
    if not os.path.isdir(os.path.join(root_path, study_name)):
        # Only sub-directories hold a study DB; skip stray files such as '.DS_Store'.
        continue
    DB_PATH = f'sqlite:///{root_path}/{study_name}/cedne_optimization_optuna.db'
    print(study_name, DB_PATH)
    # Load study from database (named by a relative path to its recording).
    study = optuna.load_study(study_name=f'../../data_sources/downloads/Atanas_2023/Control/{study_name}', storage=DB_PATH)
    best_params[study_name] = study.best_params

In [None]:
# Union of parameter names over all studies (a study may fit only a subset of neurons/edges).
# set().union(...) also handles an empty `best_params`, where indexing all_keys[0] raised IndexError;
# the generator variable no longer clobbers the notebook-level `study` object.
all_keys = set().union(*(study_params.keys() for study_params in best_params.values()))

In [None]:
# Group each parameter's best-fit values across studies, by parameter kind:
#   'gain:<n>', 'time_constant:<n>', 'baseline:<n>' -> gain_list / tconst_list / base_list[<n>]
#   'weight:<pre>:<post>:<suffix>'                   -> weight_list['<pre>-><post>']
# Only studies that actually fit a parameter contribute a value to its list.
# NOTE(review): weight keys differing only in <suffix> collapse onto the same '<pre>-><post>' entry
# (the last one seen wins) — confirm that cannot happen.
gain_list = {}
tconst_list = {}
base_list = {}
weight_list = {}
per_neuron_lists = {'gain': gain_list, 'time_constant': tconst_list, 'baseline': base_list}
for key in all_keys:
    parts = key.split(":")
    pref = parts[0]
    if pref in per_neuron_lists:
        target, suff = per_neuron_lists[pref], parts[1]
    elif pref == 'weight':
        target, suff = weight_list, '->'.join(parts[1:-1])
    else:
        print(f"Unknown parameter: {key}")
        continue
    target[suff] = [study_params[key] for study_params in best_params.values() if key in study_params]

###############################

In [None]:
# Plot loss-vs-trial for every local "Atanas*" study stored under <repository root>/tmp.
CEDNE_ROOT = os.path.dirname(os.path.abspath(cedne.__file__))
# Repository root for a src layout (<root>/src/cedne). Split on the path *component* 'src':
# the old split('src') cut at the first substring "src" anywhere in the path (e.g. '/home/srcuser/...').
# If there is no 'src' component the whole package path is kept, as before.
PACKAGE_ROOT = CEDNE_ROOT.rsplit(f"{os.sep}src{os.sep}", 1)[0]
TMP_DIR = f"{PACKAGE_ROOT}/tmp"

for study_name in sorted(os.listdir(TMP_DIR)):
    if not study_name.startswith("Atanas"):
        continue
    DB_PATH = f"sqlite:///{TMP_DIR}/{study_name}/cedne_optimization_optuna.db"
    # Load study from database
    print(study_name, DB_PATH)
    study = optuna.load_study(study_name=study_name, storage=DB_PATH)
    # Extract trial data
    trials = study.trials
    completed = [t for t in trials if t.state == optuna.trial.TrialState.COMPLETE]
    trial_numbers = [t.number for t in completed]
    loss_values = [t.value for t in completed]

    # Plot loss over trials
    f, ax = plt.subplots(figsize=(5,3))
    ax.plot(trial_numbers, loss_values, marker="o", linestyle="-", color="gray")
    ax.set_yscale("log")
    ax.set_xlabel("Trial Number")
    ax.set_ylabel("Log Loss")
    ax.set_title(f"{study_name} Loss over Trials")
    utils.simpleaxis(ax)
    plt.show()

In [None]:
# Gather the best parameters of every local "Atanas*" study in <PACKAGE_ROOT>/tmp, keyed by study name.
best_params = {}
for study_name in os.listdir(f"{PACKAGE_ROOT}/tmp"):
    if not study_name.startswith("Atanas"):
        continue
    # Each study has its own SQLite DB in a directory named after the study.
    DB_PATH = f"sqlite:///{PACKAGE_ROOT}/tmp/{study_name}/cedne_optimization_optuna.db"
    print(study_name, DB_PATH)
    study = optuna.load_study(study_name=study_name, storage=DB_PATH)
    best_params[study_name] = study.best_params

In [None]:
# Inspect the best parameters collected from the local "Atanas*" studies.
best_params

In [None]:
# Parameter names fitted in *every* study — only these can be compared one-to-one across studies.
# (`all_keys` keeps its previous meaning: one keys-view per study.)
all_keys = [study_params.keys() for study_params in best_params.values()]

# Guard the empty case: set(all_keys[0]) raised IndexError when no study was loaded.
common_keys = set(all_keys[0]).intersection(*all_keys[1:]) if all_keys else set()

In [None]:
# Group the best-fit value of every parameter shared by all studies, by parameter kind:
#   'gain:<n>', 'time_constant:<n>', 'baseline:<n>' -> gain_list / tconst_list / base_list[<n>]
#   'weight:<pre>:<post>:<suffix>'                   -> weight_list['<pre>-><post>']
# Each list holds exactly one value per study, in `best_params` order.
gain_list = {}
tconst_list = {}
base_list = {}
weight_list = {}
per_neuron_lists = {'gain': gain_list, 'time_constant': tconst_list, 'baseline': base_list}
for key in common_keys:
    parts = key.split(":")
    pref = parts[0]
    if pref in per_neuron_lists:
        target, suff = per_neuron_lists[pref], parts[1]
    elif pref == 'weight':
        target, suff = weight_list, '->'.join(parts[1:-1])
    else:
        print(f"Unknown parameter: {key}")
        continue
    target[suff] = [study_params[key] for study_params in best_params.values()]

In [None]:
#### Plotting the results

In [None]:
# One strip plot per per-neuron parameter kind, neurons in alphabetical order:
# gray dots are individual studies, the black point/bar is the mean ± std across studies.
for title, values_by_neuron in (('Gain', gain_list), ('Time Constant', tconst_list), ('Baseline', base_list)):
    neuron_names = sorted(values_by_neuron.keys())
    f, ax = plt.subplots(1, 1, figsize=(60, 3), layout='constrained')
    for j, name in enumerate(neuron_names):
        values = values_by_neuron[name]
        ax.scatter([j]*len(values), values, color='gray', alpha=0.2)
        ax.errorbar([j], y=np.mean(values), yerr=np.std(values), color='k', alpha=1, fmt='o')
    ax.set_xticks(range(len(neuron_names)))
    ax.set_xticklabels(neuron_names, rotation=45, fontsize='xx-large', ha='right')
    ax.set_ylabel('Best-fit value', fontsize='xx-large')
    ax.tick_params(axis='y', labelsize='xx-large')
    utils.simpleaxis(ax)
    f.suptitle(title, fontsize='xx-large')
    plt.show()


In [None]:
# Strip plot of every fitted weight across studies, alphabetical by '<pre>-><post>', PER_ROW per row.
# The mean marker is orange above +thres_val, purple below -thres_val, black otherwise.
thres_val = 1
PER_ROW = 100
weight_keys = sorted(weight_list.keys())
# Ceil division with at least one row: `len//100 + 1` added an empty row whenever the count was
# an exact multiple of 100.
nrows = max(1, (len(weight_keys) + PER_ROW - 1) // PER_ROW)
# squeeze=False keeps the Axes array indexable even for a single row (a lone Axes is not subscriptable).
f, ax = plt.subplots(nrows, 1, figsize=(60, 3*nrows), layout='constrained', squeeze=False)
ax = ax[:, 0]
rowwise = {}
for j, key in enumerate(weight_keys):
    row, col = divmod(j, PER_ROW)
    rowwise.setdefault(row, []).append(key)
    values = weight_list[key]
    mean = np.mean(values)
    if mean > thres_val:
        color = 'orange'
    elif mean < -thres_val:
        color = 'purple'
    else:
        color = 'k'
    ax[row].scatter([col]*len(values), values, color='gray', alpha=0.2)
    ax[row].errorbar([col], y=mean, yerr=np.std(values), color=color, alpha=1, fmt='o')

for row, row_keys in rowwise.items():
    ax[row].axhline(0, color='gray', linestyle='--', alpha=0.5)
    ax[row].set_xticks(range(len(row_keys)))
    ax[row].set_xticklabels(row_keys, rotation=45, fontsize='xx-large', ha='right')
    utils.simpleaxis(ax[row])
    ax[row].tick_params(axis='y', labelsize='xx-large')
f.suptitle('Weights')
plt.show()

In [None]:
# Same weight strip plot, but ordered by postsynaptic neuron (ties broken by presynaptic neuron,
# so the order no longer depends on dict/set iteration order).
# The mean marker is orange above +thres_val, purple below -thres_val, black otherwise.
thres_val = 1
PER_ROW = 100
weight_keys = sorted(weight_list.keys(), key=lambda k: (k.split('->')[1], k.split('->')[0]))
# Ceil division with at least one row: `len//100 + 1` added an empty row whenever the count was
# an exact multiple of 100.
nrows = max(1, (len(weight_keys) + PER_ROW - 1) // PER_ROW)
# squeeze=False keeps the Axes array indexable even for a single row (a lone Axes is not subscriptable).
f, ax = plt.subplots(nrows, 1, figsize=(60, 3*nrows), layout='constrained', squeeze=False)
ax = ax[:, 0]
rowwise = {}
for j, key in enumerate(weight_keys):
    row, col = divmod(j, PER_ROW)
    rowwise.setdefault(row, []).append(key)
    values = weight_list[key]
    mean = np.mean(values)
    if mean > thres_val:
        color = 'orange'
    elif mean < -thres_val:
        color = 'purple'
    else:
        color = 'k'
    ax[row].scatter([col]*len(values), values, color='gray', alpha=0.2)
    ax[row].errorbar([col], y=mean, yerr=np.std(values), color=color, alpha=1, fmt='o')

for row, row_keys in rowwise.items():
    ax[row].axhline(0, color='gray', linestyle='--', alpha=0.5)
    ax[row].set_xticks(range(len(row_keys)))
    ax[row].set_xticklabels(row_keys, rotation=45, fontsize='xx-large', ha='right')
    utils.simpleaxis(ax[row])
    ax[row].tick_params(axis='y', labelsize='xx-large')
f.suptitle('Weights')
plt.show()

In [None]:
## Simulating each study's rate model with its best-fit parameters

In [None]:
# For each study, the neurons that appear in at least one fitted weight 'weight:<pre>:<post>:<suffix>'.
# Accumulate into a set and sort once: rebuilding `list(set(...))` on every key was quadratic,
# and set order of strings changes between Python sessions (hash randomisation).
neurons = {}
for study_name, study_params in best_params.items():
    weight_neurons = set()
    for key in study_params:
        if key.startswith('weight'):
            _, n1, n2, _ = key.split(':')
            weight_neurons.update((n1, n2))
    neurons[study_name] = sorted(weight_neurons)

In [None]:
def parse_parameters(params):
    """Group a flat Optuna ``best_params`` dict by parameter kind.

    Parameter names use the colon-separated scheme of the optimisation:
    ``'gain:<neuron>'``, ``'time_constant:<neuron>'``, ``'baseline:<neuron>'`` and
    ``'weight:<pre>:<post>:<suffix>'`` (the trailing field is dropped).

    Args:
        params: mapping of parameter name -> fitted value.

    Returns:
        dict with keys ``'gain'``, ``'time_constant'`` and ``'baseline'`` (each neuron -> value)
        and ``'weight'`` ((pre, post) -> value). All four keys are always present, possibly
        empty, so callers can index them without a KeyError. Unrecognised names are ignored.

    Raises:
        ValueError: if a recognised name has the wrong number of ':'-separated fields.
    """
    parsed_params = {'gain': {}, 'time_constant': {}, 'baseline': {}, 'weight': {}}
    for key, value in params.items():
        kind, *fields = key.split(':')
        if kind in ('gain', 'time_constant', 'baseline'):
            (neuron,) = fields
            parsed_params[kind][neuron] = value
        elif kind == 'weight':
            pre, post, _ = fields
            parsed_params['weight'][(pre, post)] = value
    return parsed_params

In [None]:
# Load the Atanas 2023 (Flavell lab NeuroPAL) control recordings, keyed '<file name>_100_50_25'.
DATA_DIR = '/Users/sahilmoza/Documents/Postdoc/Yun Zhang/data/SteveFlavell-NeuroPAL-Cell/Control/'
jsons = {}
for js in sorted(os.listdir(DATA_DIR)):
    if js.startswith('.'):
        # Skip hidden files (e.g. macOS '.DS_Store'); they are not JSON and would crash json.load.
        continue
    with open(os.path.join(DATA_DIR, js), 'r') as fh:
        jsons[js + '_100_50_25'] = json.load(fh)

In [None]:
# Build the worm model with chemical synapses only (chem_only=True) and take its "Neutral" network,
# which the simulation cells below sub-set per study.
w = utils.makeWorm(chem_only=True)
nn_chem = w.networks["Neutral"]

In [None]:
# Map each confidently labelled neuron to its row in the recording's `trace_array`.
# The previous version numbered labels by enumerating a *set* of label strings, so the index was
# unrelated to the neuron's row and changed between Python sessions (str hashing is randomised).
# NOTE(review): assumes the keys of `labeled` are 1-based neuron indices into `trace_array`
# (data exported from Julia) — confirm against the data; set TRACE_INDEX_OFFSET = 0 if 0-based.
TRACE_INDEX_OFFSET = 1
measuredNeurons = {}
optim_neurs = {js: [] for js in jsons.keys()}
for js, p in jsons.items():
    sortedKeys = sorted(int(x) for x in p['labeled'].keys())
    # Drop uncertain identifications (labels containing '?'); a repeated label keeps its highest index.
    labelledNeurons = {p['labeled'][str(x)]['label']: x for x in sortedKeys
                       if '?' not in p['labeled'][str(x)]['label']}
    measuredNeurons[js] = {label: idx - TRACE_INDEX_OFFSET for label, idx in labelledNeurons.items()}

In [None]:
# Simulate every recording's rate model with its study's best-fit parameters and overlay the
# simulated activity (orange) on the recorded trace (gray); each panel title is their Pearson r.
# NOTE(review): assumes every `jsons` key is also a key of `neurons` / `best_params` / `measuredNeurons`.
for database in jsons.keys():
    nn_chem_sub = nn_chem.subnetwork(neuron_names=neurons[database])
    traces = jsons[database]['trace_array']
    trace_index = measuredNeurons[database]
    for neuron in nn_chem_sub.neurons:
        if neuron in neurons[database]:
            nn_chem_sub.neurons[neuron].set_property('amplitude', traces[trace_index[neuron]])

    # Assumes every trace in a recording has the same length; take it from the first measured neuron.
    num_timepoints = len(traces[trace_index[next(iter(trace_index))]])
    input_nodes = [nn_chem_sub.neurons[n] for n in nn_chem_sub.neurons if nn_chem_sub.neurons[n].type == 'sensory']
    inputs = []
    time_points = np.arange(num_timepoints)
    for inp in input_nodes:
        if hasattr(inp, 'amplitude'):
            input_value = {t: inp.amplitude[j] for j, t in enumerate(time_points)}
            inputs.append(simulator.TimeSeriesInput([inp], input_value))

    params = parse_parameters(best_params[database])

    # Defaults for anything the optimiser did not fit; fitted values override them below.
    baseline = {nn_chem_sub.neurons[n]: 0 for n in nn_chem_sub.neurons}
    gains = {nn_chem_sub.neurons[n]: 1 for n in nn_chem_sub.neurons}
    time_constants = {nn_chem_sub.neurons[n]: 1 for n in nn_chem_sub.neurons}
    weights = {(nn_chem_sub.neurons[e[0].name], nn_chem_sub.neurons[e[1].name]): 1 for e in nn_chem_sub.edges}

    # .get(kind, {}): a study with no fitted parameter of some kind keeps the defaults instead of
    # raising KeyError.
    baseline.update({nn_chem_sub.neurons[n]: v for n, v in params.get('baseline', {}).items()})
    gains.update({nn_chem_sub.neurons[n]: v for n, v in params.get('gain', {}).items()})
    time_constants.update({nn_chem_sub.neurons[n]: v for n, v in params.get('time_constant', {}).items()})
    weights.update({(nn_chem_sub.neurons[pre], nn_chem_sub.neurons[post]): v
                    for (pre, post), v in params.get('weight', {}).items()})

    rate_model = simulator.RateModel(nn_chem_sub, input_nodes, weights, gains, time_constants, baseline,
                                     static_neurons=input_nodes, time_points=time_points, inputs=inputs)
    rate_model.time_points = time_points
    res = rate_model.simulate()

    # squeeze=False keeps the Axes array indexable when only one neuron is simulated.
    f, ax = plt.subplots(figsize=(10, 2*len(res)), nrows=len(res), sharex=True, squeeze=False, layout='constrained')
    ax = ax[:, 0]
    for j, k in enumerate(res.keys()):
        node = nn_chem_sub.neurons[str(k.name)]
        utils.simpleaxis(ax[j])
        if hasattr(node, 'amplitude'):
            recorded = np.array(node.amplitude)
            ax[j].plot(time_points, recorded, color='gray')
            ax[j].set_title(f'{np.corrcoef(recorded[time_points], res[k])[0,1]}')
        ax[j].plot(time_points, res[k], color='orange', label=f'{k.name}-{node.name}')
        ax[j].legend(frameon=False)
    f.suptitle(f'{database}')
    plt.show()

In [None]:
# List what the last simulated rate model holds per neuron: name, type, baseline, gain, time constant.
for node in rate_model.neurons:
    node_params = rate_model.neurons[node]
    print(node.name, node.type, node_params.baseline, node_params.gain, node_params.time_constant)

In [None]:
# Print every edge of the last simulated rate model, with keys and attribute data
# (presumably networkx-style (pre, post, key, data) tuples — confirm).
for edata in rate_model.edges(data=True, keys=True):
    print(edata)

In [None]:
# Print the weight used in the last simulation for each sub-network edge
# (`weights` is keyed by (pre, post) neuron pairs).
for e in nn_chem_sub.edges:
    print(weights[e])

In [None]:
# Inspect the full weight dict (defaults of 1 overridden by fitted values) of the last simulation.
weights

In [None]:
# Fitted connections, as '<pre>-><post>' labels.
weight_list.keys()

In [None]:
# Attach reference synaptic weights to `nn_chem`'s connections (read as `.weight` in the next cell).
utils.loadSynapticWeights(nn_chem)

In [None]:
# Look up the reference weight (loaded by utils.loadSynapticWeights) of each fitted '<pre>-><post>' edge.
leifer_weight = {}
for edge_label in weight_list.keys():
    pre_name, post_name = edge_label.split('->')
    edge_id = (nn_chem.neurons[pre_name], nn_chem.neurons[post_name], 0)
    leifer_weight[edge_label] = nn_chem.connections[edge_id].weight

In [None]:
# Compare fitted mean weights with the reference (Leifer) weights by magnitude, using only edges
# that are clearly non-zero in both (|w| > weight_thres), and fit a straight line.
# A dedicated name avoids clobbering the `thres_val` used by the weight-plot cells.
weight_thres = 0.2
fit_w = []
leifer_w = []
for key in weight_list.keys():
    mean_fit = np.mean(weight_list[key])
    if np.abs(leifer_weight[key]) > weight_thres and np.abs(mean_fit) > weight_thres:
        leifer_w.append(leifer_weight[key])
        fit_w.append(mean_fit)
fit_w, leifer_w = np.array(fit_w), np.array(leifer_w)
abs_fit, abs_leifer = np.abs(fit_w), np.abs(leifer_w)

f, ax = plt.subplots(figsize=(3,3))
ax.scatter(abs_fit, abs_leifer, color='gray', alpha=1)
slope, intercept, r_value, p_value, std_err = ss.linregress(abs_fit, abs_leifer)
# Draw the fit over the observed range of |fitted weight| (was a hard-coded 0..2).
x = np.linspace(0, abs_fit.max(), 100)
ax.plot(x, slope*x + intercept, color='k')
ax.set_xlabel('|fitted weight| (mean over studies)')
ax.set_ylabel('|Leifer weight|')
ax.set_title(f'r = {r_value:.2f}, p = {p_value:.2g}')
utils.simpleaxis(ax)
plt.show()

In [None]:
# Pool every study's best-fit values per parameter kind and show one histogram for each.
def _pool(values_by_name):
    """Concatenate the per-name value lists of a parameter-kind dict into one flat list."""
    return [value for values in values_by_name.values() for value in values]

flat_w = _pool(weight_list)
flat_gain = _pool(gain_list)
flat_tconst = _pool(tconst_list)
flat_base = _pool(base_list)

for title, pooled in (('Weights', flat_w), ('Gains', flat_gain),
                      ('Time Constants', flat_tconst), ('Baselines', flat_base)):
    plt.hist(pooled)
    plt.title(title)
    plt.show()

In [None]:
# Rough number of shared fitted parameters: three per neuron (gain, time constant, baseline —
# assumes every neuron in gain_list also has the other two) plus one per weight.
3*len(gain_list) + len(weight_list)

In [None]:
# Effective dimensionality of the explored parameter space of each local "Atanas*" study:
# cumulative explained variance of a PCA over the parameter vectors of all finished trials.
# NOTE(review): parameters are not standardised, so wide-range parameters dominate the components.
from sklearn.decomposition import PCA
for study_name in sorted(os.listdir(f"{PACKAGE_ROOT}/tmp")):
    if not study_name.startswith("Atanas"):
        continue
    DB_PATH = f"sqlite:///{PACKAGE_ROOT}/tmp/{study_name}/cedne_optimization_optuna.db"
    # Load study from database
    print(study_name, DB_PATH)
    study = optuna.load_study(study_name=study_name, storage=DB_PATH)
    finished = [trial for trial in study.trials if trial.value is not None]
    if not finished:
        print(f"{study_name}: no finished trials, skipping")
        continue
    # Fix the column order by parameter name; rows built from `trial.params.values()` relied on
    # every trial's dict having the same insertion order, silently misaligning columns otherwise.
    param_names = sorted(finished[0].params)
    data_trials = np.array([[trial.params[name] for name in param_names] for trial in finished])
    pca = PCA().fit(data_trials)
    explained_variance = np.cumsum(pca.explained_variance_ratio_)

    # Plot the cumulative variance explained by the leading principal components.
    plt.plot(range(1, len(explained_variance)+1), explained_variance, marker='o')
    plt.xlabel("Number of Principal Components")
    plt.ylabel("Cumulative Explained Variance")
    plt.title(f"Effective Dimensionality of Optimized Parameters\n{study_name}")
    plt.show()

In [None]:
# Parameter matrix (one row per finished trial) of the last study analysed by the PCA cell.
data_trials