In [None]:
import sys
sys.path.insert(0, '/home/hombresabio/AI/Waves/visavis-seir') # in order to be able to import from scripts.py

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from scripts.client import VisAVisClient
from scripts.make_protocol import make_protocol

In [None]:
from subplots_from_axsize import subplots_from_axsize

def plot_result(outfile, result, title=None):
    """Render an E/I occupancy heatmap of one simulation result and save it to ``outfile``.

    For every (time, position) pair the image value is the fraction of rows
    with ``E > 0`` plus the fraction of rows with ``I > 0``.

    Parameters
    ----------
    outfile : str or path-like
        Destination passed to ``Figure.savefig``.
    result : object
        Simulation result whose ``states`` attribute is a DataFrame with at
        least the columns ``seconds``, ``h``, ``E`` and ``I``
        (presumably the object returned by ``VisAVisClient.run`` -- TODO confirm).
    title : str, optional
        Red caption drawn near the top of the image.
    """
    data = result.states

    fig, ax = subplots_from_axsize(
        1, 1, axsize=(20, 8),
        left=0., right=0., bottom=0., top=0.
    )
    try:
        # Binarize E and I per row; averaging per (seconds, h) gives the fraction of
        # rows (presumably cells) at that spot that are in each state.
        occupied = data[['seconds', 'h']].copy()
        for state in ('E', 'I'):
            occupied[state] = data[state] > 0

        per_spot = occupied.groupby(['seconds', 'h'])
        img_E = per_spot['E'].mean().unstack().to_numpy().T
        img_I = per_spot['I'].mean().unstack().to_numpy().T
        img = img_E + img_I  # rows: h, columns: time

        ax.imshow(
            img,
            cmap='gray',
            origin='lower',
            aspect='auto',
            interpolation='none',
        )

        ax.set_xlabel('time')
        ax.set_axis_off()

        if title is not None:
            ax.annotate(
                title, (0.5, 0.9),
                xycoords='axes fraction', va='center', ha='center',
                color='red', fontsize=32
            )

        fig.savefig(outfile)
    finally:
        # Release the figure even if saving fails, so batch runs do not accumulate figures.
        plt.close(fig)

In [None]:
# Default model parameters, passed verbatim as ``parameters_json`` to ``VisAVisClient.run``.
# NOTE(review): r_incr = 0.0667 looks like ~1/15 -- confirm the intended value.
PARAMETERS_DEFAULT = dict(
    c_rate=1,
    e_incr=1,
    i_incr=1,
    r_incr=0.0667,
)

In [None]:
# Run `sim_num` stochastic simulations for every (channel height, pulse interval)
# combination and collect each run's `states` table in `results`.
VISAVIS_ROOT = '/home/hombresabio/AI/Waves/visavis-seir'
CHANNEL_HEIGHTS = [7]
PULSE_INTERVALS = [100]

sim_num = 100
results = []
for channel_height in CHANNEL_HEIGHTS:
    # One simulator binary per channel height.
    client = VisAVisClient(visavis_bin=f'{VISAVIS_ROOT}/target/bins/vis-a-vis-{channel_height}')

    for interval in PULSE_INTERVALS:
        protocol_file_path = make_protocol(pulse_intervals=[interval, 1500], duration=4, out_folder='./')  # alt: f'./interval-{interval}'

        for sim in range(sim_num):
            result = client.run(parameters_json=PARAMETERS_DEFAULT, protocol_file_path=protocol_file_path)
            results.append(result.states)

        #plot_result(
        #    f"./tkankony_{channel_height}_{interval}.png",
        #    result,
        #    title=f"channel height: {channel_height}, interval: {interval}"
        #)

# Visualization

In [None]:
def get_activations(img, h, plot=False, roll=5, std=2):
    """Find local maxima over time of the smoothed infection signal at one channel position.

    Parameters
    ----------
    img : pandas.DataFrame
        Infection image as built by ``get_infected_img``: one row per position
        ``h`` (index), one column per time point (columns named ``seconds``).
    h : label of the row (channel position) to analyse; looked up by label, not position.
    plot : bool, default False
        If True, draw the smoothed signal and mark the detected peaks on the current axes.
    roll : int, default 5
        Length of the centred Gaussian rolling window.
    std : float, default 2
        Standard deviation of the Gaussian window.

    Returns
    -------
    pandas.DataFrame
        One row per peak: the time column (``seconds`` for images from
        ``get_infected_img``) and a column named ``h`` holding the smoothed value there.
    """
    # Label-based lookup: callers pass h values taken from the data, which only
    # coincide with row positions when the h labels start at 0 and have no gaps.
    row = img.loc[h]
    roll_start = row.rolling(roll, center=True, win_type='gaussian', min_periods=1).mean(std=std)
    slope = roll_start.diff().fillna(0)
    # A peak is a point reached by a non-negative step and followed by a falling step.
    is_peak = (slope >= 0) & (slope.shift(-1) < 0)
    activations_start = roll_start[is_peak].reset_index()
    if plot:
        plt.scatter(activations_start.iloc[:, 0], activations_start[h], color='red')
        plt.plot(roll_start)
    return activations_start

def get_activations_time(img, t, plot=False, roll=10):
    """Find local maxima along the channel of the smoothed infection signal at one time point.

    Parameters
    ----------
    img : pandas.DataFrame
        Infection image as built by ``get_infected_img``: one row per position
        ``h`` (index), one column per time point.
    t : label of the column (time point) to analyse.
    plot : bool, default False
        If True, draw the smoothed profile and mark the detected peaks on the current axes.
    roll : int, default 10
        Length of the centred Gaussian rolling window; its std is ``roll / 6``.

    Returns
    -------
    pandas.DataFrame
        One row per peak: the position column (``h`` for images from
        ``get_infected_img``) and a column named ``t`` holding the smoothed value there.
    """
    # Column lookup avoids transposing the whole image on every call.
    profile = img[t]
    roll_start = profile.rolling(roll, center=True, win_type='gaussian', min_periods=1).mean(std=roll / 6)
    slope = roll_start.diff().fillna(0)
    # A peak is a point reached by a non-negative step and followed by a falling step.
    is_peak = (slope >= 0) & (slope.shift(-1) < 0)
    activations_start = roll_start[is_peak].reset_index()
    if plot:
        # The first column is the position index (previously this read a non-existent
        # `seconds` column and raised AttributeError).
        plt.scatter(activations_start.iloc[:, 0], activations_start[t], color='red')
        plt.plot(roll_start)
    return activations_start

def get_infected_img(df):
    """Build an (h x seconds) image of summed E plus summed I.

    Rows are the distinct ``h`` values, columns the distinct ``seconds`` values;
    (seconds, h) pairs absent from ``df`` become NaN.
    """
    per_cell = df.groupby(['seconds', 'h'])
    exposed, infectious = (per_cell[state].sum().unstack().T for state in ('E', 'I'))
    return exposed + infectious

def track(prev, prev_indices, curr, used_indices, threshold=2):
    """Greedily link spot positions in the current frame to spots of the previous frame.

    Performs a two-pointer sweep, so ``prev`` and ``curr`` must both be sorted
    in ascending order (assumption -- TODO confirm callers sort).

    Parameters
    ----------
    prev : iterable of float
        Positions in the previous frame.
    prev_indices : iterable
        Track index of each entry of ``prev`` (same order).
    curr : iterable of float
        Positions in the current frame (any iterable or iterator).
    used_indices : int
        Next unused track index.
    threshold : float, default 2
        Maximum |x - y| for a current position ``y`` to continue the track of ``x``.

    Returns
    -------
    curr_indices : list
        Track index assigned to each position in ``curr``.
    used_indices : int
        Updated next-unused track index. It must be taken from the return
        value, because ints are immutable.
    """
    exhausted = (np.inf, np.inf)  # sentinel: no previous spot left, so everything is a new track
    prev_iterator = iter(zip(prev, prev_indices))
    x, x_ind = next(prev_iterator, exhausted)

    curr_indices = []
    for y in curr:
        # Skip previous spots that lie too far below y to ever be matched.
        while x < y - threshold:
            x, x_ind = next(prev_iterator, exhausted)
        if x > y + threshold:
            # Nothing close enough: start a new track.
            curr_indices.append(used_indices)
            used_indices += 1
        else:
            # Continue the previous spot's track; each previous spot is used at most once.
            curr_indices.append(x_ind)
            x, x_ind = next(prev_iterator, exhausted)
    return curr_indices, used_indices

In [None]:
from laptrack import LapTrack
max_distance = 25
# With the "sqeuclidean" metric, costs are squared distances, so each cutoff is max_distance squared.
sq_cutoff = max_distance**2
lt = LapTrack(
    # Metrics accept any `scipy.spatial.distance.cdist` metric name.
    track_dist_metric="sqeuclidean",
    track_cost_cutoff=sq_cutoff,
    splitting_dist_metric="sqeuclidean",
    splitting_cost_cutoff=sq_cutoff,  # or False for non-splitting case
    merging_dist_metric="sqeuclidean",
    merging_cost_cutoff=sq_cutoff,  # or False for non-merging case
)


In [None]:
# Example run (index 1 of `results`): for every time column, smooth the E+I profile
# along h and keep the positions of its local maxima as "spots".
df = results[1]
img = get_infected_img(df)
df_spots = (
    pd.concat(
        [get_activations_time(img, t, roll=20).h for t in img.columns],
        keys=img.columns,
        names=['seconds'],
    )
    .reset_index()
)

In [None]:
# Coarsen time into integer frame numbers for LapTrack (presumably states are sampled
# every 4 s -- TODO confirm). NOTE: not idempotent; re-running this cell divides again.
df_spots['seconds'] = df_spots['seconds'] // 4

In [None]:
# Link spots across frames: 'seconds' (now a frame number) is the frame column and
# 'h' the only spatial coordinate.
spots_input = df_spots[['h', 'seconds']].sort_values('seconds')
track_df, split_df, merge_df = lt.predict_dataframe(
    spots_input,
    coordinate_cols=["h"],
    frame_col="seconds",
    # False: the returned track_df also keeps columns not in coordinate_cols
    # (False will be the default in the major release).
    only_coordinate_cols=False,
)

In [None]:
# Detected spots: one row per local maximum of the smoothed E+I profile in each time frame.
df_spots

In [None]:
# Distribution of spots per frame: index = number of spots, value = number of frames with that many.
df_spots.groupby('seconds').size().value_counts()

In [None]:
import matplotlib
# NOTE(review): switching to TkAgg mid-notebook sends this and every later figure to an
# external Tk window (and plt.show() may block); a `%matplotlib tk` cell is the usual way.
matplotlib.use("TkAgg")

plt.figure(figsize=(3, 3))
frames = track_df.index.get_level_values("frame")
frame_range = [frames.min(), frames.max()]
k1, k2 = "seconds", "h"
keys = [k1, k2]


def get_track_end(track_id, first=True):
    """Return the (seconds, h) position of the first or last spot of ``track_id``."""
    track_spots = track_df[track_df["track_id"] == track_id].sort_index(level="frame")
    return track_spots.iloc[0 if first else -1][keys]


# Each track: its spots coloured by track id, joined by a black polyline in frame order.
for track_id, grp in track_df.groupby("track_id"):
    track_spots = grp.reset_index().sort_values("frame")
    plt.scatter(track_spots[k1], track_spots[k2], c=track_spots["track_id"], vmin=0, vmax=20, cmap='tab20')
    plt.plot(track_spots[k1], track_spots[k2], "-k")

# Split/merge events: connect the parent's last spot to the child's first spot.
# Drawn once, outside the per-track loop (previously redrawn for every track).
for _, row in list(split_df.iterrows()) + list(merge_df.iterrows()):
    pos1 = get_track_end(row["parent_track_id"], first=False)
    pos2 = get_track_end(row["child_track_id"], first=True)
    # Label access: positional `pos[0]` on a labelled Series is deprecated in pandas.
    plt.plot([pos1[k1], pos2[k1]], [pos1[k2], pos2[k2]], "-k")


# plt.xticks([])
# plt.yticks([])
plt.show()
# plt.xlim(300/4, 350/4)
# plt.ylim(200/4, 260/4)

In [None]:
# Unique (track_id, tree_id) pairs: which tree each track belongs to.
track_df.loc[:, ['track_id', 'tree_id']].drop_duplicates()

In [None]:
plt.set_cmap('rainbow')  # only affects images; line colours are set explicitly below
# Smoothed E+I profile along the channel (h) at consecutive time points of `img`.
times = range(232, 260, 4)
# set_cmap does not colour plotted lines, so take the colours from the rainbow map directly.
line_colors = plt.cm.rainbow(np.linspace(0, 1, len(times)))
for t, color in zip(times, line_colors):
    roll_start = img[t].rolling(5, center=True, win_type='gaussian', min_periods=1).mean(std=2)  # img rows = position in channel
    plt.plot(roll_start, color=color, label=f"t = {t}")
plt.xlim(20, 100)
plt.xlabel('h')
plt.legend();

In [None]:
# Heatmap of `img` (rows: h, drawn bottom-up; columns: time).
ax_1 = plt.gca()
ax_1.imshow(img.to_numpy(), cmap='gray', origin='lower', aspect='auto', interpolation='none')
ax_1.set(xlabel='time')
ax_1.set_axis_off()

In [None]:
#activations_start = activations_start[((activations_start.seconds - activations_start.seconds.shift(1) > 40)) | (activations_start.seconds.shift(1).isna())]

cols = 4
rows = 20
plt.figure(figsize=(4 * cols, 4 * rows))

# One row per run: heatmap, peaks at the channel start, peaks at the channel end,
# and the moving average of total infected cells.
for i, run_states in enumerate(results[:rows]):
    img = get_infected_img(run_states)
    h_max = run_states['h'].max()
    h_min = run_states['h'].min() + 2  # at h = 0 we have empty cells blocking the signal

    ax_1 = plt.subplot(rows, cols, cols * i + 1)
    ax_1.imshow(
        img.to_numpy(),
        cmap='gray',
        origin='lower',
        aspect='auto',
        interpolation='none',
    )
    ax_1.set_xlabel('time')
    ax_1.set_axis_off()

    ax_2 = plt.subplot(rows, cols, cols * i + 2)
    activ_start = get_activations(img, h_min, plot=True)  # draws on the current axes (ax_2)
    ax_2.set_title(f"Activations at start: {len(activ_start)}")

    ax_3 = plt.subplot(rows, cols, cols * i + 3)
    activ_end = get_activations(img, h_max, plot=True)  # draws on the current axes (ax_3)
    ax_3.set_title(f"Activations at end: {len(activ_end)}")

    ax_4 = plt.subplot(rows, cols, cols * i + 4)
    roll_total_infected = img.sum(axis=0).rolling(10).mean()  # sum over all positions at given moment
    ax_4.plot(roll_total_infected)
    ax_4.set_title("Infected cells moving average")


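**TODO:** What happens to the second pulse? (*Co dzieje się z drugim pulsem?*)
1. It arrives (*dociera*).
2. It vanishes (*znika*).
3. It gives rise to new fronts (*rodzi*) — which ones and how many (*co i ile*):
   - a) when (*kiedy*),
   - b) how far from the first pulse (*jak daleko od 1 pulsu*).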

# Analysis

In [None]:
# For each run, detect peaks of the smoothed signal at the channel start (h_min) and end (h_max).
# Record every peak time, plus per-run counts and the first two arrival times at the end.
activations = []
stats = []
for i, df in enumerate(results):
    img = get_infected_img(df)
    h_min, h_max = df['h'].min() + 2, df['h'].max()  # skip the first rows: empty cells there block the signal

    activations_start = get_activations(img, h_min)
    activations_end = get_activations(img, h_max)
    times_start = activations_start.seconds
    times_end = activations_end.seconds

    activations += [(time, 'start') for time in times_start]
    activations += [(time, 'finish') for time in times_end]

    if len(times_end) < 2:
        # Pad with NaN so the first/second arrival lookups below never go out of range.
        times_end = np.append(times_end, [np.nan, np.nan])

    stats.append((len(activations_start), len(activations_end), times_end[0], times_end[1]))

In [None]:
# Long form: one row per detected peak, with its time and where it was seen ('start' / 'finish').
df_activations = pd.DataFrame(data=activations, columns=['time', 'type'])
# One row per run: peak counts at start/end and the first two arrival times at the end (NaN if absent).
df_stats = pd.DataFrame(
    data=stats,
    columns=['Start Activations', 'Finish Activations', 'First Finish', 'Second Finish'],
)

In [None]:
# Histogram of all detected arrival times at the channel end, with the observed
# ranges of first (green) and second (red) arrival times marked.
finish_activations = df_activations[df_activations.type == "finish"]
plt.title("Time of arrival of signal at the end of a channel")
# Long-form call on the numeric column; passing the whole frame made seaborn treat it as wide-form data.
sns.histplot(data=finish_activations, x='time', bins=100)

plt.axvline(x=df_stats['First Finish'].min(), color='g', label='first arrival range')
plt.axvline(x=df_stats['First Finish'].max(), color='g')

plt.axvline(x=df_stats['Second Finish'].min(), color='red', label='second arrival range')
plt.axvline(x=df_stats['Second Finish'].max(), color='red')
plt.legend()

#plt.axvline(x=920, color = 'pink')
#df_stats.loc[df_stats['Second Finish'] > 920, 'Second Finish'] = np.nan

In [None]:
# Fraction of runs in which a second arrival at the channel end was detected.
# Printed labels (Polish): "P dotarcia" = P(arrival), "P zaniku/kolizji" = P(vanishing/collision).
p = df_stats['Second Finish'].notna().mean()
print(f"1. P dotarcia: {p:.2f}\n2. P zaniku/kolizji: {1-p:.2f}")

In [None]:
# Runs where both the first and the second arrival at the channel end were detected.
df_second_fine = df_stats.dropna(subset=['Second Finish', 'First Finish'])

In [None]:
# Number of fronts beyond the expected ones: presumably 2 per line, one per pulse
# (hence -2 per line and -4 for the sum) -- TODO confirm.
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
axes[0].set_title("Additional fronts - start line", fontsize=10)
axes[1].set_title("Additional fronts - finish line", fontsize=10)
axes[2].set_title("Additional fronts - sum", fontsize=10)
sns.histplot(df_second_fine['Start Activations'] - 2, stat='probability', ax=axes[0])
sns.histplot(df_second_fine['Finish Activations'] - 2, stat='probability', ax=axes[1])
sns.histplot(df_second_fine['Start Activations'] + df_second_fine['Finish Activations'] - 4, stat='probability', ax=axes[2])
