# Bootstrapping

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from _notebooks import analysis
from importlib import reload
from tqdm import tqdm
import warnings

L_box = 50  # NOTE(review): presumably the simulation box size in simulation units; not used in the visible cells — confirm
mu_factor = 6  # conversion to microns (the x / y columns are multiplied by this after loading)
min_factor = 8  # conversion to minutes

In [2]:
import yaml


def apply_time_filter(df, dt):
    """
    Down-sample trajectories to at most one row per run and time window.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``"time[hr]"`` and ``"rid"``.
    dt : float, in min
        Width of one time window.

    Returns
    -------
    pandas.DataFrame
        A new frame carrying an extra ``"ts"`` column (the window index,
        ``time[hr] * 60 // dt``) and only the first row of every
        ``("ts", "rid")`` pair, re-indexed from 0.  The input frame is left
        untouched.
    """
    # ``assign`` returns a copy, so the caller's frame does not silently gain
    # a "ts" column as a side effect of filtering.
    windowed = df.assign(ts=df["time[hr]"] * 60 // dt)
    return windowed.drop_duplicates(subset=["ts", "rid"], keep="first").reset_index(
        drop=True
    )


def _get_mp_type(yfile):
    """Return the ``["substrate"]["kind"]`` entry of the YAML file at `yfile`."""
    with open(yfile) as stream:
        config = yaml.safe_load(stream)
    return config["substrate"]["kind"]


def _make_title(df):
    """
    Build a multi-line plot title from the parameters in the first row of `df`.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame whose first row carries the configuration columns
        ``gamma``, ``R_eq``, ``mag_std``, ``add_rate`` and ``gid``.

    Returns
    -------
    str
        One ``"<label> = <value>\\n"`` line per parameter found in `df`,
        with mathtext labels.
    """
    labels = {
        "gamma": r"$\gamma$",
        "R_eq": r"$R_{eq}$",
        "mag_std": r"$\sigma_{MVG}$",
        "add_rate": r"$\tau_{MVG}$",
        "gid": "ID",
    }
    first_row = df.iloc[0]
    # Look the parameters up by column name.  The earlier positional slice
    # ([3:8]) only landed on these columns for one specific column layout and
    # raised KeyError for any other (e.g. the raw x, y, vx, vy, ... frames).
    return "".join(
        f"{label} = {first_row[key]}\n"
        for key, label in labels.items()
        if key in first_row.index
    )


def linear_init_pts(xmin, xmax, vmin, vmax, n_pts, s=1, basin_only=False):
    """
    Sample points (x, v) on the straight line

        v = s * (vmax - vmin) / (xmax - xmin) * (x - xmin) + vmin

    Parameters
    ----------
    xmin, xmax, vmin, vmax : float
        Define the line (it passes through ``(xmin, vmin)``).
    n_pts : int
        Number of x-values per sampled interval.
    s : float, optional
        Multiplier applied to the slope.
    basin_only : bool, optional
        If False, x runs over ``n_pts`` evenly spaced values on
        ``[xmin, xmax]``.  If True, x consists of ``n_pts`` values within
        +/-10 of ``xmin`` followed by ``n_pts`` values within +/-10 of
        ``xmax``.

    Returns
    -------
    numpy.ndarray
        Array with one ``[x, v]`` row per point.
    """
    if basin_only:
        half_width = 10
        xs = np.concatenate(
            [
                np.linspace(xmin - half_width, xmin + half_width, n_pts),
                np.linspace(xmax - half_width, xmax + half_width, n_pts),
            ]
        )
    else:
        xs = np.linspace(xmin, xmax, n_pts)

    vs = s * (vmax - vmin) / (xmax - xmin) * (xs - xmin) + vmin
    return np.column_stack((xs, vs))


def F_init_pts(F, bounds):
    """
    Convert the indices of the non-NaN entries of `F` into (x, v) coordinates.

    Parameters
    ----------
    F : 2-D numpy.ndarray
        Rows are mapped onto the v axis, columns onto the x axis.
    bounds : tuple
        ``(xmin, xmax, vmin, vmax, nbins)``.

    Returns
    -------
    numpy.ndarray
        One ``[x, v]`` row per non-NaN entry, where
        ``x = xmin + col * (xmax - xmin) / nbins`` and
        ``v = vmin + row * (vmax - vmin) / nbins``.
    """
    rows, cols = np.nonzero(~np.isnan(F))
    xmin, xmax, vmin, vmax, nbins = bounds

    v_coords = vmin + rows * (vmax - vmin) / nbins
    x_coords = xmin + cols * (xmax - xmin) / nbins
    return np.column_stack((x_coords, v_coords))


def init_lattice(F, bounds, buffer=1):
    """
    Build an ``nbins x nbins`` lattice and pick the nodes where `F` is defined.

    Parameters
    ----------
    F : 2-D numpy.ndarray
        Rows index the v axis, columns the x axis.
        NOTE(review): assumed to be ``nbins x nbins`` — confirm.
    bounds : tuple
        ``(xmin, xmax, vmin, vmax, nbins)``.
    buffer : float, optional
        Margin subtracted from both ends of the x range and of the v range.

    Returns
    -------
    X, Y : numpy.ndarray
        ``np.meshgrid`` of the lattice axes.
    pts : numpy.ndarray
        One ``[x, v]`` row per non-NaN entry of `F`.
    """
    xmin, xmax, vmin, vmax, nbins = bounds
    x_axis = np.linspace(xmin + buffer, xmax - buffer, nbins)
    v_axis = np.linspace(vmin + buffer, vmax - buffer, nbins)
    X, Y = np.meshgrid(x_axis, v_axis)

    valid = np.argwhere(~np.isnan(F))
    row_idx, col_idx = valid[:, 0], valid[:, 1]
    pts = np.column_stack((x_axis[col_idx], v_axis[row_idx]))
    return X, Y, pts


def get_xv_traj_for_gid(gid):
    """
    Load the long single run of grid `gid` and return its (x, v) trajectory.

    Reads ``single_run_long/pkls/fulltake_gid{gid}.pkl``, rescales the x and y
    columns by ``mu_factor`` and passes x and ``time[hr]`` to
    ``analysis.calc_v_a_from_position``.

    Returns
    -------
    numpy.ndarray
        The ``"x"`` and ``"v"`` columns of that result, one row per entry.
    """
    pkl_path = f"../_server/sim_data/corners_only_2/single_run_long/pkls/fulltake_gid{gid}.pkl"
    long_run = pd.read_pickle(pkl_path)

    # positions -> microns
    for axis in ("x", "y"):
        long_run[axis] *= mu_factor

    xva = analysis.calc_v_a_from_position(long_run.x, long_run["time[hr]"])
    return xva[["x", "v"]].to_numpy()


def sample_dataframe(df, size, s=0, rng=None):
    """
    Draw a bootstrap sample of whole runs, with replacement.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``"rid"`` column; each distinct value is one run.
    size : int
        Number of runs to draw.
    s : int, optional
        Offset added to the wall-clock seed.  Only used when `rng` is None.
    rng : numpy.random.RandomState or numpy.random.Generator, optional
        Random generator to draw from.  Pass a seeded generator to obtain a
        reproducible sample; when omitted, the generator is seeded from the
        current time plus `s` (different on every call, not reproducible).

    Returns
    -------
    pandas.DataFrame
        The rows of the drawn runs, concatenated in draw order.

    Notes
    -----
    A run drawn more than once appears several times in the result under the
    same ``rid`` (and with repeated index labels), so a later
    ``groupby("rid")`` puts all copies into a single group.
    """
    import time

    if rng is None:
        rng = np.random.RandomState(seed=s + int(time.time()))

    runs = [run for _, run in df.groupby("rid")]
    # choice(n) draws from arange(n): one position in `runs` per drawn run
    picks = rng.choice(len(runs), replace=True, size=size)
    return pd.concat([runs[k] for k in picks])


def prep_dir(gid):
    """
    Create the output directories for the bootstrap results of grid `gid`.

    Returns
    -------
    tuple of str
        ``(root, sub_root, stream_root)``: the grid directory, its ``SDs/``
        sub-directory and its ``streamlines/`` sub-directory.  Every path
        ends with ``/``.
    """
    import os

    root = (
        f"../_server/sim_data/corners_only_2/phase_space_bootstrap_samples/grid_{gid}/"
    )
    sub_root = root + "SDs/"
    stream_root = root + "streamlines/"

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs
    for path in (root, sub_root, stream_root):
        os.makedirs(path, exist_ok=True)

    return root, sub_root, stream_root


def pretty_imshow(
    im,
    cbar_title,
    mark_basins=False,
    save_path=None,
    basin_x=(133, 167),
    **imshow_kwargs,
):
    """
    Show `im` as an image with a colorbar and x / v axis labels.

    Parameters
    ----------
    im : array-like
        Image data handed to ``plt.imshow``.
    cbar_title : str
        Label of the colorbar.
    mark_basins : bool, optional
        Draw dashed orange vertical lines at the positions in `basin_x`.
    save_path : str, optional
        If given, the figure is written there and then closed.
    basin_x : sequence of float, optional
        x-positions of the marker lines.
    **imshow_kwargs
        Forwarded to ``plt.imshow``.
    """
    plt.figure(figsize=(5, 3.5), dpi=300)

    plt.imshow(im, **imshow_kwargs)

    cbar = plt.colorbar()
    cbar.set_label(cbar_title)
    plt.xlabel(r"$x$ ($\mu$m)")
    plt.ylabel(r"$v$ ($\mu$m/hr)")
    plt.axis("auto")

    if mark_basins:
        # Take the vertical span from this call (the `extent` kwarg, else the
        # current y-limits) instead of from notebook globals `vmin` / `vmax`,
        # which only exist after a later cell has been run.
        extent = imshow_kwargs.get("extent")
        if extent is not None:
            v_lo, v_hi = extent[2], extent[3]
        else:
            v_lo, v_hi = plt.ylim()
        plt.vlines(list(basin_x), v_lo, v_hi, colors="orange", linestyles="dashed")

    if save_path is not None:
        plt.savefig(save_path)
        plt.close()
    # When the figure was saved it has been closed above, so this call only
    # displays it in the unsaved case.
    plt.show()


#### Load grid by grid
- Load each `fulltake_gid*.pkl` file, which holds all of the runs for one configuration (grid)

In [11]:
from glob import glob

# glob returns paths in arbitrary (filesystem) order; sort for a stable `data` order
files = sorted(glob("../_server/sim_data/corners_only_2/pkls/fulltake_gid*.pkl"))
if not files:
    # fail here with a clear message rather than with an IndexError further down
    raise FileNotFoundError(
        "No fulltake_gid*.pkl files found in ../_server/sim_data/corners_only_2/pkls/"
    )

data = []

for file in files:
    df = pd.read_pickle(file)
    # positions -> microns
    df.x *= mu_factor
    df.y *= mu_factor
    data.append(df)

# grid id -> position of that grid's frame in `data`
gid_to_indx = {int(df.gid.iloc[0]): k for k, df in enumerate(data)}
# NOTE(review): presumably (integration step) * (steps per saved frame) * (minutes per time unit) / 60 — confirm
dt = 0.0015 * 250 * 8 / 60  # hr

print(f"Loaded {len(data)} configurations.")
# plain loop: a list comprehension used only for printing builds a throwaway list of None
for df in data:
    print(
        f"\t - Grid {df.gid.iloc[0]}: {len(df)} total data points | {len(df.rid.unique())} runs | {len(df[df.rid==0])} data points / run"
    )
display(data[0])

Loaded 32 configurations.
	 - Grid 0: 365503 total data points | 768 runs | 475 data points / run
	 - Grid 1: 367744 total data points | 768 runs | 480 data points / run
	 - Grid 10: 342373 total data points | 720 runs | 478 data points / run
	 - Grid 11: 318858 total data points | 672 runs | 479 data points / run
	 - Grid 12: 437760 total data points | 912 runs | 480 data points / run
	 - Grid 13: 437760 total data points | 912 runs | 480 data points / run
	 - Grid 14: 425560 total data points | 893 runs | 476 data points / run
	 - Grid 15: 405026 total data points | 851 runs | 477 data points / run
	 - Grid 16: 343479 total data points | 720 runs | 475 data points / run
	 - Grid 17: 345560 total data points | 720 runs | 480 data points / run
	 - Grid 18: 345594 total data points | 720 runs | 480 data points / run
	 - Grid 19: 345453 total data points | 720 runs | 480 data points / run
	 - Grid 2: 344771 total data points | 720 runs | 477 data points / run
	 - Grid 20: 437760 total da

Unnamed: 0,x,y,vx,vy,time[hr],gamma,R_eq,mag_std,add_rate,gid,rid
0,113.500000,150.000000,0.000000,0.000000,0.00,0.8,2.5,15000.0,3.0,0,460
1,115.137492,150.137737,0.693140,0.446027,0.05,0.8,2.5,15000.0,3.0,0,460
2,115.322311,152.056528,0.150598,0.913096,0.10,0.8,2.5,15000.0,3.0,0,460
3,115.688387,152.477372,0.414260,-0.284334,0.15,0.8,2.5,15000.0,3.0,0,460
4,116.142769,151.285669,0.030631,-0.569939,0.20,0.8,2.5,15000.0,3.0,0,460
...,...,...,...,...,...,...,...,...,...,...,...
466,120.831018,151.935035,0.191622,-0.011869,23.75,0.8,2.5,15000.0,3.0,0,292
467,121.219676,151.392487,0.488644,-0.435206,23.80,0.8,2.5,15000.0,3.0,0,292
468,121.159620,150.808940,-0.163576,-0.056389,23.85,0.8,2.5,15000.0,3.0,0,292
469,120.563663,150.238964,0.607091,-0.363632,23.90,0.8,2.5,15000.0,3.0,0,292


### Bootstrap for one grid

Draw `N_sample` bootstrap samples; each sample resamples whole runs with replacement and contains as many runs as the original data set.

In [None]:
gid = 9
root_path, msd_path, stream_root = prep_dir(gid)

# all runs of the selected grid
df_gid_original = data[gid_to_indx[gid]]

# every bootstrap sample draws as many runs as the original data set holds
n = df_gid_original["rid"].unique().size
N_sample = 25

bootstrapped_dfs = [
    sample_dataframe(df_gid_original, size=n, s=seed_offset)
    for seed_offset in range(N_sample)
]

Compute $x$, $v$, and $a$. Bin them.

In [None]:
xva_samples = []

# with a 112mu mp, this gives bins of dim 3.5mu x 3.5mu
nbins = 32

for sample_df in tqdm(bootstrapped_dfs, total=N_sample):
    # x, v, a values for this entire config
    grid_x_v_a = []

    # configuration columns (position 5 onward of the first row), copied onto every run
    config_vals = sample_df.iloc[0][5:]

    # compute speed and acc for each run
    for rid, df_rid in sample_df.groupby("rid"):
        # A run drawn several times by the bootstrap shows up as back-to-back
        # copies under the same rid.  The clock restarts at every copy, so a
        # negative time step marks the start of a new copy.  Each copy is
        # processed on its own so that the jump from the end of one copy to
        # the start of the next is not treated as part of a trajectory.
        # NOTE(review): assumes time[hr] never decreases within a single run — confirm.
        copy_id = df_rid["time[hr]"].diff().lt(0).cumsum().to_numpy()
        for _, df_copy in df_rid.groupby(copy_id):
            x_v_a = analysis.calc_v_a_from_position(df_copy.x, df_copy["time[hr]"])
            x_v_a[config_vals.index] = config_vals
            grid_x_v_a.append(x_v_a)

    grid_x_v_a = pd.concat(grid_x_v_a)

    # return value is not used; presumably adds the bin indices to the frame in place — confirm
    analysis.get_bin_indices(grid_x_v_a, nbins)

    xva_samples.append(grid_x_v_a)

Compute $F(x, v)$ and $\sigma(x, v)$

In [None]:
F_samples = []
sigma_samples = []

# Running phase-space bounds over all samples.  Start from +/-inf so the first
# sample always sets them; fixed +/-200 sentinels give wrong bounds whenever
# the data lies entirely above 200 or below -200.
xmin, xmax, vmin, vmax = np.inf, -np.inf, np.inf, -np.inf

for sample in tqdm(xva_samples, total=N_sample):
    F, F_std_err, sigma = analysis.calc_F_sigma(sample, dt, nbins, min_pts=5)
    F_samples.append(F)
    sigma_samples.append(sigma)

    xmin = min(xmin, sample["x"].min())
    xmax = max(xmax, sample["x"].max())
    vmin = min(vmin, sample["v"].min())
    vmax = max(vmax, sample["v"].max())


Compute $\left[F_s(x, v) - F_{\mathrm{avg}}(x, v)\right]^2$

- Plot each instance
- Average over the samples to get the mean squared deviation (MSD) and plot its natural log

In [None]:
# bin-wise mean over all bootstrap samples, ignoring NaN entries
F_avg = np.nanmean(F_samples, axis=0)
F_SD = []

# Set to True to write one image per sample into `msd_path`.  The ffmpeg cell
# reads sample_<k>.png frames from that directory (savefig normally appends
# ".png" when the path has no extension), so they have to be produced here
# first.  Off by default: one 300-dpi figure is rendered per sample.
SAVE_SD_FRAMES = False

for s in range(N_sample):
    F_sqrd_diff = (F_samples[s] - F_avg) ** 2
    F_SD.append(F_sqrd_diff)
    if SAVE_SD_FRAMES:
        pretty_imshow(
            F_sqrd_diff,
            cbar_title=r"$\left(F_s - F_{\mathrm{avg}}\right)^2$",
            save_path=msd_path + f"sample_{s}",
            mark_basins=True,
            origin="lower",
            extent=[xmin, xmax, vmin, vmax],
            interpolation="bilinear",
        )


In [None]:
import os
import shutil
import subprocess

# Stitch the per-sample frames in `msd_path` into `root_path`/SDs.mp4.
# The command is an argument list (no shell), so paths containing spaces or
# shell metacharacters are passed through unchanged.
cmd = [
    "ffmpeg",
    "-i", os.path.join(msd_path, "sample_%d.png"),
    "-b:v", "4M",
    "-s", "600x600",
    "-pix_fmt", "yuv420p",
    "-filter:v", "setpts=2.*PTS",
    os.path.join(root_path, "SDs.mp4"),
    "-y",
    "-hide_banner",
    "-loglevel", "fatal",
]

if shutil.which("ffmpeg") is None:
    print("ffmpeg not found on PATH; SDs.mp4 was not created.")
else:
    exit_code = subprocess.run(cmd, check=False).returncode
    if exit_code != 0:
        # report the failure instead of dropping it silently
        print(f"ffmpeg exited with code {exit_code}; SDs.mp4 may be missing or incomplete.")

In [None]:
# mean of the squared deviations over all samples, ignoring NaN entries
F_MSD = np.nanmean(F_SD, axis=0)

msd_plot_kwargs = dict(
    cbar_title=r"ln$\left(\left\langle\left(F_s - F_{\mathrm{avg}}\right)^2\right\rangle\right)$",
    save_path=root_path + "MSD",
    mark_basins=True,
    origin="lower",
    extent=[xmin, xmax, vmin, vmax],
    interpolation="bilinear",
)
pretty_imshow(np.log(F_MSD), **msd_plot_kwargs)

Plot streamlines from each $F_s(x, v)$

In [None]:
# Build the title from the last bootstrap sample explicitly.  This is the same
# frame the loop variable `sample` of the F/sigma cell ends on, without
# depending on a name leaked from another cell's loop.
plot_title = _make_title(xva_samples[-1])
mp_type = _get_mp_type(f"../configs/IM/grid_id{gid}/simbox.yaml")
plot_title += f"substrate = {mp_type}"
bounds = (xmin, xmax, vmin, vmax, nbins)
title = {"title": plot_title, "size": 20}

# one streamline figure per bootstrap sample, saved as sample_<k>.png in `stream_root`
for k, F in enumerate(F_samples):
    # only the seed points are needed here; the meshgrid arrays are discarded
    _, _, init_pts = init_lattice(F, bounds)

    fig = analysis.F_streamplot(
        F,
        bounds,
        stream_init_pts=init_pts,
        title=title,
        interp="bilinear",
        do_try=False,
        streamplot_kwargs={
            "integration_direction": "forward",
            "color": "black",
            "broken_streamlines": False,
            "density": 1,
            "linewidth": 0.5,
        },
        save_path=stream_root + f"sample_{k}.png",
    )