In [2]:
%matplotlib inline
import os
import numpy as np
import matplotlib as mpl
import pickle

import readdy_learn.analyze.analyze as ana
import readdy_learn.analyze.basis as basis
from pathos.multiprocessing import Pool

import pynumtools.kmc as kmc

mpl.rcParams['figure.figsize'] = (16, 13)
import matplotlib.pyplot as plt
import scipy.signal as ss
from readdy_learn.example.regulation_network import RegulationNetwork
from readdy_learn.example.regulation_network import sample_lsq_rates
from readdy_learn.example.regulation_network import sample_along_alpha

# Try to generate a regulation network such that least squares (LSQ) does not work properly

### Case 1
- LSQ fits well but doesn't necessarily recover the rates (varying $\alpha$)
- In the limit of low noise (and same initial conditions), the least-squares solution recovers almost the right reactions (sparsity pattern). Vary the hyperparameters ($\alpha$, ~~$\lambda$~~)
- use more basis functions

- timestep = $8\cdot 10^{-3}$: seems OK
- timestep = $6\cdot 10^{-3}$: LSQ still works reasonably well
- timestep = $12 \cdot 10^{-3}$: LSQ begins to get worse
    - this almost works; it probably needs a few more frames
- timestep = $24 \cdot 10^{-3}$: LSQ picks up processes that are definitely not there and ignores some of the correct ones
    - could not recover the correct rates with LASSO

In [3]:
def plot_l1_errors(results, cutoff=0., desired_rates=None):
    """Plot the L1 error of estimated rates against alpha and return its minimum.

    For every key (regularization strength alpha) in ``results``, the L1 distance
    between each realization's rates and the reference rates is computed. The mean
    over realizations is drawn as a line, and mean +/- std as a shaded band, on a
    logarithmic alpha axis.

    Parameters
    ----------
    results : dict
        Maps alpha -> estimated rates. According to the original notes the shape
        is (n_realizations, n_basis_functions). A 1-D rate vector is treated as a
        single realization.
    cutoff : float
        Currently unused; kept so that existing calls keep working.
    desired_rates : array_like, optional
        Reference rates to compare against. When omitted, the notebook-global
        ``regulation_network.desired_rates`` is used (the previous behaviour, which
        raises NameError if no such global exists).

    Returns
    -------
    tuple
        ``(min_mean_l1_error, ix)``, where ``ix`` indexes ``sorted(results.keys())``.

    Raises
    ------
    ValueError
        If ``results`` is empty.
    """
    if desired_rates is None:
        # Legacy fallback: relies on a global `regulation_network` existing.
        desired_rates = regulation_network.desired_rates
    if not results:
        raise ValueError("results is empty, nothing to plot")
    desired_rates = np.asarray(desired_rates)

    keys = sorted(results.keys())
    keys_sorted = np.array(keys)
    l1_errors = np.empty(len(keys))
    l1_std = np.empty(len(keys))
    for i, key in enumerate(keys):
        rates = np.asarray(results[key])
        if rates.ndim == 1:
            # a single realization stored as a flat vector
            rates = rates[np.newaxis, :]
        ratesdiff = np.abs(rates - desired_rates)
        # one L1 norm per realization (first axis), summed over all remaining axes
        l1norms = ratesdiff.reshape(len(ratesdiff), -1).sum(axis=1)
        l1_errors[i] = np.mean(l1norms)
        l1_std[i] = np.std(l1norms)

    plt.fill_between(keys_sorted, l1_errors - l1_std, l1_errors + l1_std,
                     color='b', alpha=.5)
    plt.plot(keys_sorted, l1_errors)
    plt.xscale('log')
    plt.xlabel(r'$\alpha$')
    plt.ylabel('L1 error')
    ix = int(np.argmin(l1_errors))
    return l1_errors[ix], ix


def _plot_results(desired_rates, result, lsq_rates):
    """Show the LSQ bar plot, the L1-error curve and the best regularized bar plot.

    ``result`` maps alpha -> regularized rates (as produced by sample_along_alpha).
    Returns ``(error, ix)`` from plot_l1_errors; ``ix`` indexes ``sorted(result.keys())``.
    """
    ana.plot_rates_bar(desired_rates, lsq_rates)
    plt.title('LSQ')
    plt.show()
    error, ix = plot_l1_errors(result, desired_rates=desired_rates)
    plt.show()
    # `ix` refers to the sorted keys of `result`, so look the alpha up there rather
    # than in a separately constructed alphas array that may not match.
    best_alpha = sorted(result.keys())[ix]
    ana.plot_rates_bar(desired_rates, np.array(result[best_alpha]).squeeze())
    plt.title('regularized')
    plt.show()
    return error, ix


def get_regulation_network(timestep, return_analysis=False):
    """Build a noise-free, single-realization RegulationNetwork for ``timestep``.

    Only the network's second initial state is kept. For it, the trajectory is
    obtained via ``generate_or_load_traj_lma``, its (n_frames, n_species) shape is
    printed, and the gradient derivatives are computed without persisting them.

    Parameters
    ----------
    timestep : float
        Time step assigned to the network before the trajectory is generated.
    return_analysis : bool
        If True, also return the analysis object holding the trajectory (needed
        for least-squares fits). Defaults to False for backward compatibility.

    Returns
    -------
    RegulationNetwork, or (RegulationNetwork, analysis) if ``return_analysis`` is True.
    """
    regulation_network = RegulationNetwork()
    regulation_network.timestep = timestep
    regulation_network.realisations = 1.
    regulation_network.noise_variance = 0.
    # keep only the second initial state
    regulation_network.initial_states = [regulation_network.initial_states[1]]
    # NOTE(review): the file prefix/postfix do not depend on `timestep`. If
    # generate_or_load_traj_lma caches by file name, a trajectory generated for a
    # different timestep might be reused -- confirm against readdy_learn.
    analysis = regulation_network.generate_analysis_object(fname_prefix='case_1', fname_postfix='0')
    for i in range(len(regulation_network.initial_states)):
        analysis.generate_or_load_traj_lma(i, regulation_network.target_time,
                                           noise_variance=regulation_network.noise_variance,
                                           realizations=regulation_network.realisations)
        shape = analysis.trajs[i].counts.shape
        print("n_frames={}, n_species={}".format(*shape))
    regulation_network.compute_gradient_derivatives(analysis, persist=False)
    if return_analysis:
        return regulation_network, analysis
    return regulation_network


def do(timestep, fname, alphas=None):
    """Run LSQ and an alpha scan for one ``timestep``, plot the results and pickle them.

    Parameters
    ----------
    timestep : float
        Time step used to generate the trajectory.
    fname : str
        Output pickle path. The pickle stores a dict with the keys
        'regularized_rates' (result of sample_along_alpha), 'lsq_rates' and
        'desired_rates'.
    alphas : array_like, optional
        Regularization strengths to scan; defaults to ``np.logspace(-6, -4, num=200)``.
    """
    if alphas is None:
        alphas = np.logspace(-6, -4, num=200)
    # The LSQ fit below runs on the analysis object that holds the trajectory.
    regulation_network, analysis = get_regulation_network(timestep, return_analysis=True)
    result = sample_along_alpha(regulation_network, alphas=alphas)
    lsq_rates = analysis.least_squares(0, tol=1e-16, recompute=True, persist=False, verbose=False)
    error, ix = _plot_results(regulation_network.desired_rates, result, lsq_rates)

    data = {
        'regularized_rates': result,
        'lsq_rates': lsq_rates,
        # stored so show_results does not need to rebuild the network
        'desired_rates': regulation_network.desired_rates,
    }

    print("minimal l1 error: {} (ix {})".format(error, ix))
    with open(fname, 'wb') as f:
        pickle.dump(data, f)


def show_results(fname):
    """Load a pickle written by ``do`` and re-create its plots.

    Pickles from older runs do not contain 'desired_rates'. In that case the
    reference rates are taken from a freshly built network (timestep 1e-3),
    assuming they do not depend on the timestep -- TODO confirm.
    """
    print("loading results from {}....".format(fname))
    # pickle.load can execute arbitrary code: only load files you produced yourself.
    with open(fname, 'rb') as f:
        data = pickle.load(f)

    desired_rates = data.get('desired_rates')
    if desired_rates is None:
        desired_rates = get_regulation_network(1e-3).desired_rates

    error, ix = _plot_results(desired_rates, data['regularized_rates'], data['lsq_rates'])
    print("minimal l1 error: {} (ix {})".format(error, ix))

# dt = 9e-3, 334 frames

In [None]:
# Scan alpha for dt = 9e-3 and store the fitted rates.
do(timestep=9e-3, fname='case_1_dt_9e-3.pickle')

n_frames=334, n_species=9


# dt = 8e-3, 375 frames

In [None]:
# Scan alpha for dt = 8e-3 and store the fitted rates.
do(timestep=8e-3, fname='case_1_dt_8e-3.pickle')

# dt = 12e-3, 250 frames

In [None]:
# Scan alpha for dt = 12e-3 and store the fitted rates.
do(timestep=12e-3, fname='case_1_dt_12e-3.pickle')

# dt = 16e-3, 188 frames

In [None]:
# Scan alpha for dt = 16e-3 and store the fitted rates.
do(timestep=16e-3, fname='case_1_dt_16e-3.pickle')

# dt = 24e-3, 125 frames

In [None]:
# Scan alpha for dt = 24e-3 and store the fitted rates.
do(timestep=24e-3, fname='case_1_dt_24e-3.pickle')