In [None]:
import sys
sys.path.insert(0, '..') # in order to be able to import from scripts.py

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scripts.client import VisAVisClient
from scripts.make_protocol import make_protocol
from laptrack import LapTrack
import matplotlib
matplotlib.use("TkAgg")

In [None]:
from subplots_from_axsize import subplots_from_axsize

def plot_result(outfile, result, title=None):
    """Render a kymograph of E/I activity for one simulation and save it.

    Parameters
    ----------
    outfile : str or path-like
        Destination passed to ``fig.savefig``.
    result : object
        Simulation result; only ``result.states`` is read. ``states`` must be a
        DataFrame with 'seconds', 'h', 'E' and 'I' columns.
    title : str, optional
        If given, drawn in red near the top centre of the image.

    Each pixel (row = h, column = time) is the fraction of rows at that
    ('seconds', 'h') cell with E > 0 plus the fraction with I > 0, so values
    lie in [0, 2]. Cells missing from ``states`` are NaN.
    """
    data = result.states

    # Per-row occupancy flags, averaged over all rows sharing a (seconds, h) cell.
    occupied = data[['E', 'I']] > 0
    fractions = occupied.groupby([data['seconds'], data['h']]).mean()
    # unstack -> rows = seconds, columns = h; transpose so time runs left to right.
    img = (fractions['E'] + fractions['I']).unstack().to_numpy().T

    fig, ax = subplots_from_axsize(
        1, 1, axsize=(20, 8),
        left=0., right=0., bottom=0., top=0.
    )
    try:
        ax.imshow(
            img,
            cmap='gray',
            origin='lower',
            aspect='auto',
            interpolation='none',
        )
        # Axes are hidden entirely, so no axis labels are drawn.
        ax.set_axis_off()

        if title is not None:
            ax.annotate(
                title, (0.5, 0.9),
                xycoords='axes fraction', va='center', ha='center',
                color='red', fontsize=32
            )

        fig.savefig(outfile)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)

In [None]:
# Default parameter set sent to the simulator as `parameters_json` in the
# simulation cell. NOTE(review): the meaning and units of each entry are
# defined by the vis-a-vis simulator, not here — confirm there.
PARAMETERS_DEFAULT = dict(
    c_rate=1,
    e_incr=1,
    i_incr=1,
    r_incr=0.0667,
)

In [None]:
import os

# --- Simulation sweep configuration -------------------------------------------
SIM_NUM = 50               # runs per (channel height, pulse interval) setting
CHANNEL_HEIGHTS = [7]      # selects the binary `vis-a-vis-{height}`
PULSE_INTERVALS = [100]    # first entry of `pulse_intervals` given to make_protocol
# Directory holding the vis-a-vis binaries. Read from the environment so the
# notebook is not tied to one user's home directory; the default is the
# previously hard-coded location.
VISAVIS_BIN_DIR = os.environ.get(
    'VISAVIS_BIN_DIR',
    '/home/hombresabio/AI/Waves/visavis-seir/target/bins',
)
PLOT_LAST_RESULT = False   # True: save a kymograph of the last run of each setting

sim_num = SIM_NUM  # legacy name kept for older cells
results = []       # one `result.states` DataFrame per run, in run order
results_meta = []  # parallel to `results`: which setting produced each entry

for channel_height in CHANNEL_HEIGHTS:
    client = VisAVisClient(
        visavis_bin=os.path.join(VISAVIS_BIN_DIR, f'vis-a-vis-{channel_height}'),
    )

    for interval in PULSE_INTERVALS:
        protocol_file_path = make_protocol(
            pulse_intervals=[interval, 1500],
            duration=4,
            out_folder='./',
        )

        for sim in range(SIM_NUM):
            result = client.run(
                parameters_json=PARAMETERS_DEFAULT,
                protocol_file_path=protocol_file_path,
            )
            results.append(result.states)
            results_meta.append(
                {'channel_height': channel_height, 'interval': interval, 'sim': sim}
            )

        if PLOT_LAST_RESULT:
            plot_result(
                f"./tkankony_{channel_height}_{interval}.png",
                result,
                title=f"channel height: {channel_height}, interval: {interval}",
            )

# Track

In [None]:
def get_activations(img, h, plot = False, roll=5, std=2):
    """Find local maxima over time of the smoothed activity in one row of `img`.

    Parameters
    ----------
    img : pandas.DataFrame
        Activity image with one row per channel position and one column per
        time point; the column index is expected to be named 'seconds' (as
        produced by ``get_infected_img``).
    h : int
        *Positional* row index (``img.iloc[h]``), not a row label.
    plot : bool
        If True, plot the smoothed trace and mark detected peaks in red.
    roll : int
        Width of the centred Gaussian smoothing window (default 5, as before).
    std : float
        Standard deviation of the Gaussian window (default 2, as before).

    Returns
    -------
    pandas.DataFrame
        One row per peak: a 'seconds' column (from the column index name) and
        a column of smoothed values named after the row *label* of `img`.
    """
    smoothed = img.iloc[h].rolling(
        roll, center=True, win_type='gaussian', min_periods=1
    ).mean(std=std)
    step = smoothed.diff().fillna(0)
    # Peak = did not decrease from the previous sample and decreases into the
    # next one. The first sample counts as non-decreasing; the last sample has
    # no successor (NaN after shift) and therefore never qualifies.
    is_peak = (step >= 0) & (step.shift(-1) < 0)
    activations_start = smoothed[is_peak].reset_index()
    if plot:
        # The value column carries the row label, which need not equal `h`.
        plt.scatter(activations_start.seconds, activations_start[smoothed.name], color = 'red')
        plt.plot(smoothed)
    return activations_start

def get_activations_time(img, t, plot = False, roll = 5):
    """Mark, for every time column of `img`, the rows where smoothed activity peaks.

    Smoothing runs down the rows (along channel position), independently for
    each column.

    Parameters
    ----------
    img : pandas.DataFrame
        Activity image with rows = channel positions and columns = time points
        (as produced by ``get_infected_img``).
    t : hashable
        Column label of the time point to draw when `plot` is True; unused otherwise.
    plot : bool
        If True, plot the smoothed profile of column `t` and mark its peaks in red.
    roll : int
        Width of the centred Gaussian window; its std is ``roll / 6``.

    Returns
    -------
    pandas.DataFrame
        Same shape as `img`: the smoothed value at peak positions, NaN elsewhere.
    """
    smoothed = img.rolling(roll, center=True, win_type='gaussian', min_periods = 1).mean(std=roll/6)
    step = smoothed.diff()
    # Peak = non-decreasing from the previous row and decreasing into the next.
    # Missing neighbours satisfy their half of the test, so the first row only
    # needs a following decrease and the last row only a non-negative step.
    is_peak = (step.fillna(np.inf) >= 0) & (step.shift(-1).fillna(-np.inf) < 0)
    activations_start = smoothed[is_peak]
    if plot:
        # Time points are column labels, so select column `t` and plot it
        # against the row index (there is no 'seconds' column to read).
        plt.scatter(activations_start.index, activations_start[t], color = 'red')
        plt.plot(smoothed[t])
    return activations_start

def get_infected_img(df):
    """Return an image of summed 'E' + 'I' values, rows = 'h', columns = 'seconds'.

    'E' and 'I' are each summed over rows sharing a ('seconds', 'h') pair and
    the two totals are added. Pairs that never occur in `df` come out as NaN.
    """
    totals = df.groupby(['seconds', 'h'])[['E', 'I']].sum()
    combined = totals['E'] + totals['I']
    # unstack puts 'h' in the columns; transpose to index by 'h'.
    return combined.unstack().T

In [None]:
def get_tracks(img, lt, eps = 1e-3, roll=20, frame_step=4):
    """Detect activation peaks in `img` and link them into tracks with LapTrack.

    Parameters
    ----------
    img : pandas.DataFrame
        Activity image indexed by 'h' with time points as columns (as built by
        ``get_infected_img``); the index must be named 'h'.
    lt : laptrack.LapTrack
        Configured tracker; its ``predict_dataframe`` performs the linking.
    eps : float
        Small per-frame offset added to 'h' (``h + eps * frame``). Presumably
        this breaks exact coordinate ties between frames — TODO confirm.
    roll : int
        Smoothing window passed to ``get_activations_time`` (default 20, as before).
    frame_step : int
        Divisor turning 'seconds' column labels into frame numbers (default 4,
        as before). NOTE(review): assumes states are sampled every 4 s.

    Returns
    -------
    tuple
        ``(track_df, split_df, merge_df)`` as returned by ``lt.predict_dataframe``.
    """
    activations = get_activations_time(img, 0, roll=roll)
    df_activations = (
        activations.reset_index()
        .melt(id_vars=["h"], var_name="seconds", value_name="peak")
        .dropna()  # non-peak cells are NaN; keep only detected peaks
    )
    frames = df_activations['seconds'].astype(int) // frame_step
    df_activations = df_activations.assign(
        seconds=frames,
        h=df_activations['h'] + eps * frames,
    )
    track_df, split_df, merge_df = lt.predict_dataframe(
        df_activations[['h', 'seconds']].sort_values(by='seconds'),
        coordinate_cols=["h"],
        frame_col="seconds",
        only_coordinate_cols=False,
    )
    return (track_df, split_df, merge_df)

def plot_tracks(track_df, split_df, merge_df):
    """Plot tracks ('h' against 'seconds') and draw split/merge connections.

    Parameters
    ----------
    track_df : pandas.DataFrame
        Track table from ``LapTrack.predict_dataframe``: needs a 'frame' index
        level and 'track_id', 'seconds' and 'h' columns.
    split_df, merge_df : pandas.DataFrame
        Tables with 'parent_track_id' and 'child_track_id' columns.
    """
    k1, k2 = "seconds", "h"
    keys = [k1, k2]

    def get_track_end(track_id, first=True):
        # (seconds, h) of the first or last point of a track, ordered by frame.
        df = track_df[track_df["track_id"] == track_id].sort_index(level="frame")
        return df.iloc[0 if first else -1][keys]

    for track_id, grp in track_df.groupby("track_id"):
        df = grp.reset_index().sort_values("frame")
        # tab20 has 20 colours; ids outside [0, 20] are clipped to the end colours.
        plt.scatter(df[k1], df[k2], c=df["track_id"], vmin=0, vmax=20, cmap='tab20')
        # A single polyline draws the same segments as consecutive point pairs.
        plt.plot(df[k1], df[k2], "-k")

    # Draw each split/merge link once, after all tracks.
    for _, row in list(split_df.iterrows()) + list(merge_df.iterrows()):
        pos1 = get_track_end(row["parent_track_id"], first=False)
        pos2 = get_track_end(row["child_track_id"], first=True)
        plt.plot([pos1[k1], pos2[k1]], [pos1[k2], pos2[k2]], "-k")

    plt.xlabel(k1)
    plt.ylabel(k2)
    plt.show()

# Analysis

In [None]:
max_distance = 25
# Linker configuration. Distances use scipy `cdist` metric names; with the
# squared-Euclidean metric the costs are squared distances, so the cutoffs
# are squared as well.
lt = LapTrack(
    track_dist_metric="sqeuclidean",
    track_cost_cutoff=max_distance**2,
    splitting_dist_metric="sqeuclidean",
    splitting_cost_cutoff=max_distance**2,  # set to False to disable splitting
    merging_dist_metric="sqeuclidean",
    merging_cost_cutoff=False,  # merging disabled; use a number to enable it
)

In [None]:
# Run the tracking pipeline on one simulation to inspect its output.
EXAMPLE_SIM_ID = 2  # index into `results`
assert EXAMPLE_SIM_ID < len(results), (
    f"EXAMPLE_SIM_ID={EXAMPLE_SIM_ID} but only {len(results)} simulations were run"
)
df = results[EXAMPLE_SIM_ID]
img = get_infected_img(df)

track_df, split_df, merge_df = get_tracks(img, lt)

In [None]:
events = []


def get_events_df(track_df, split_df):
    """Annotate LapTrack split events with front directions, start time and tree.

    Parameters
    ----------
    track_df : pandas.DataFrame
        Track table with 'track_id', 'tree_id', 'seconds' and 'h' columns.
    split_df : pandas.DataFrame
        Split table with 'parent_track_id' and 'child_track_id' columns.

    Returns
    -------
    pandas.DataFrame
        ``split_df`` plus:
        'front_direction_parent' / 'front_direction_child' — sign (+1, -1, 0)
        of the mean per-step change of 'h' along the parent / child track;
        NaN when the track has a single point, so no direction can be estimated.
        'seconds' — first 'seconds' value of the child track.
        'tree_id' — tree id of the child track.
    """
    # One 'h' per (track, time); pivot_table averages duplicates and sorts the index.
    positions = pd.pivot_table(track_df, values='h', index=['track_id', 'seconds'])['h']
    # Mean step of 'h' between consecutive points of each track.
    front_speed = positions.groupby('track_id').diff().groupby('track_id').mean()
    # np.sign keeps NaN for single-point tracks instead of forcing them to -1
    # as `(speed > 0) * 2 - 1` did.
    front_direction = np.sign(front_speed).rename('front_direction')
    track_start = track_df.groupby('track_id')['seconds'].min()
    tree_df = track_df[['track_id', 'tree_id']].drop_duplicates().set_index('track_id')
    events_df = (
        split_df
        .join(front_direction, on='parent_track_id')
        .join(front_direction, on='child_track_id', lsuffix='_parent', rsuffix='_child')
        .join(track_start, on='child_track_id')
        .join(tree_df, on='child_track_id')
    )
    return events_df

for i, df in enumerate(results):
    img = get_infected_img(df)
    track_df, split_df, merge_df = get_tracks(img, lt)
    if len(split_df) > 0:
        events_df = get_events_df(track_df, split_df)
        events_df.index.name = 'event_id'
        events.append(events_df)
    else:
        # No splits in this run: append a placeholder so positions in `events`
        # stay aligned with positions in `results`.
        events.append(pd.DataFrame())

In [None]:
# Each entry of `events` belongs to the simulation at the same position in
# `results`; `keys` turns that position into the outer 'sim_id' index level.
assert len(events) == len(results), (
    f"{len(events)} event tables for {len(results)} simulations - "
    "re-run the event-extraction cell"
)
df_results = pd.concat(events, names=['sim_id'], keys=range(len(results)))
df_results