In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np

from src.utils import load_data, predict_and_plot_SLC
import seaborn as sns

sns.set_style("whitegrid")
sns.set_palette('colorblind')
mpl.rcParams['font.family'] = 'Arial'

In [2]:
# Load data
data = load_data([0,151])

# IMPORTANT: If you want to test the code you can use a small value here and it should run fast
# If you want to reproduce the results in the paper use sample_size=3000
sample_size = 10

In [None]:
def generate_predictions(data, timestamp, max_val_un, sample_size, mean_prior):
    """Generate predictions for different resolutions."""
    print(sample_size)
    print("A")
    predictions = [predict_and_plot_SLC(data, timestamp, max_val_un, i, plot=False,
                                      sample_size=sample_size, mean_prior=mean_prior)
                  for i in range(2,9)]
    print("V")
    data_for_plot = []
    for r, prediction in zip(range(2,9), predictions):
        mean, var = prediction[1]
        data_list = []
        for m, v in zip(mean, var):
            data_list += list(np.random.normal(m, np.sqrt(v), size=sample_size))

        x = []
        y = []
        for thres in range(0, max_val_un, 5):
            x.append(thres)
            y.append(len([i for i in data_list if i > thres]) / len(data_list))
        data_for_plot.append((r, data_list, x, y))

    return data_for_plot

def create_plots(timestamp, max_val_un, data_for_plot, mean_label, idx_mean):
    """Create and save PDF and CDF plots."""
    plt.figure(figsize=(10,14))

    ticksize = 22
    label = 32
    legend = 20

    plt.subplot(2,1,1)
    for r, data_list, x, y in data_for_plot:
        sns.kdeplot(np.array(data_list), label=str(r) + " km", gridsize=400)

    plt.subplot(2,1,2)
    for r, data_list, x, y in data_for_plot:
        plt.plot(x, [1 - i for i in y], label=str(r) + " km")

    plt.subplot(2,1,1)
    plt.xlim(0, max_val_un)
    plt.xlabel("SLC (mm)", fontsize=label)
    plt.ylabel("Density", fontsize=label)
    plt.xticks(fontsize=ticksize)
    plt.yticks(fontsize=ticksize)
    plt.title(f"PDF, prior mean "+r'$M_{t_0}$'+f"={mean_label} " + r'$ma^{-1}$', fontsize=label, pad=20)
    #if idx_mean >= 2:
    plt.legend(fontsize=legend, loc='upper left')

    plt.subplot(2,1,2)
    plt.xlim(0, max_val_un)
    plt.xlabel("SLC (mm)", fontsize=label)
    plt.ylabel("P(S<S*)", fontsize=label)
    plt.xticks(fontsize=ticksize)
    plt.yticks(fontsize=ticksize)
    #if idx_mean >= 2:
    plt.legend(fontsize=legend, loc='upper left')
    plt.title(f"CDF, prior mean "+r'$M_{t_0}$'+f"={mean_label} " + r'$ma^{-1}$', fontsize=label, pad=20)

    plt.tight_layout()
    plt.subplots_adjust(hspace=0.4)  # Increase space between subplots
    plt.savefig(f"./../assets/plots/fig_5/cdf_pdf_{timestamp}_mean_{mean_label}.pdf", bbox_inches='tight')

# Figure 4: this can take > 10 minute to run
timestamp_list = [100,125,150]
max_val_list= [150,200,300,400]
all_data = []
means = [(2.2,10),(2.872,20),(3.52,40)] #(3.285,30),

for idx_mean, (mean_prior, mean_label) in enumerate(means):
    # Run the stuff and store the data
    for timestamp, max_val_un in zip(timestamp_list, max_val_list):
        print("A") 
        data_for_plot = generate_predictions(data, timestamp, max_val_un, sample_size, mean_prior)
        all_data.append((timestamp, max_val_un, data_for_plot))

    # Plot in a different loop
    for timestamp, max_val_un, data_for_plot in all_data:
        create_plots(timestamp, max_val_un, data_for_plot, mean_label, idx_mean)


A
10
A
WEEEE
(104, 1)
(104, 1)
WEEEE
