In [1]:
import datetime
import os
import random

import matplotlib.colors as colors
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import matplotlib.text as mpl_text
import numpy as np
import pandas as pd
import plottools.colors as c
from matplotlib.gridspec import GridSpec
from sklearn.neighbors import KernelDensity
import cmocean

import gridtools as gt
from plotstyle import PlotStyle

%matplotlib qt

In [2]:
class AnyObject:
    """Lightweight legend handle that carries a text label and its colour.

    ``AnyObjectHandler.legend_artist`` reads ``my_text`` and ``my_color``
    from such a handle to draw a coloured text entry in a legend.
    """

    def __init__(self, text, color):
        self.my_text, self.my_color = text, color


class AnyObjectHandler(object):
    """Legend handler that renders a handle's ``my_text`` in ``my_color``.

    Intended to be registered through a legend's ``handler_map`` for
    ``AnyObject`` handles.
    """

    def legend_artist(self, legend, orig_handle, fontsize, handlebox):
        """Create a text artist for ``orig_handle`` and add it to ``handlebox``.

        Parameters
        ----------
        legend : matplotlib.legend.Legend
            The legend being drawn (unused; part of the handler protocol).
        orig_handle : object
            Handle exposing ``my_text`` and ``my_color`` (e.g. ``AnyObject``).
        fontsize : float
            Legend font size (unused; part of the handler protocol).
        handlebox : matplotlib.offsetbox.DrawingArea
            Container the artist is added to.

        Returns
        -------
        matplotlib.text.Text
            The artist that was added to ``handlebox``.
        """
        # Text is anchored at the handle box origin, left/baseline aligned.
        patch = mpl_text.Text(
            x=0,
            y=0,
            text=orig_handle.my_text,
            color=orig_handle.my_color,
            verticalalignment="baseline",
            horizontalalignment="left",
            multialignment=None,
            fontproperties=None,
            rotation=0,
            linespacing=None,
            rotation_mode=None,
        )
        handlebox.add_artist(patch)
        return patch


def data_center2d(x, y, start, stop):
    """Return the centre of the bounding box of the points ``start:stop``.

    Parameters
    ----------
    x, y : array-like
        Point coordinates.
    start, stop : int
        Slice bounds selecting the points to consider.

    Returns
    -------
    tuple
        ``(x_centre, y_centre)``: the midpoints between the minimum and
        maximum of the selected x and y values.
    """
    window = slice(start, stop)
    xs, ys = x[window], y[window]

    lo_x, hi_x = np.min(xs), np.max(xs)
    lo_y, hi_y = np.min(ys), np.max(ys)

    return (lo_x + (hi_x - lo_x) / 2, lo_y + (hi_y - lo_y) / 2)


def get_datetime(folder):
    """Parse the recording start time encoded in a recording folder name.

    The last path component of ``folder`` must have the form
    ``YYYY-MM-DD-HH_MM`` or ``YYYY-MM-DD-HH:MM`` (e.g.
    ``.../2016-04-09-22_25/``). A trailing path separator is optional.

    Parameters
    ----------
    folder : str
        Path to the recording folder.

    Returns
    -------
    datetime.datetime
        Recording start time; seconds are always 0.

    Raises
    ------
    ValueError, IndexError
        If the folder name does not match the expected pattern.
    """
    # normpath drops a trailing separator, so ".../name/" and ".../name" both
    # yield "name". Blindly stripping the last character would cut off a
    # minute digit when no trailing slash is given.
    name = os.path.basename(os.path.normpath(folder))
    rec_year, rec_month, rec_day, rec_time = name.split("-")

    # the time field is "HH_MM" (filesystem-safe) or "HH:MM"
    try:
        hour, minute = int(rec_time.split("_")[0]), int(rec_time.split("_")[1])
    except (ValueError, IndexError):
        hour, minute = int(rec_time.split(":")[0]), int(rec_time.split(":")[1])

    return datetime.datetime(
        year=int(rec_year),
        month=int(rec_month),
        day=int(rec_day),
        hour=hour,
        minute=minute,
        second=0,
    )


def dodge_data_y(data, start, stop):
    """Pick a y position for an annotation that avoids the event's data.

    The value range of ``data`` is split at the mean of ``data[start:stop]``
    into a part below and a part above the event, each sampled in unit steps
    with ``np.arange``. The function returns the median of the part with more
    steps, i.e. roughly the middle of the larger region.

    NaNs in ``data`` are ignored. On a tie the region above the event is
    chosen, so the result is deterministic. If neither region contains a
    step (constant data), the event mean itself is returned.

    Parameters
    ----------
    data : array-like
        1-D data plotted over time.
    start, stop : int
        Slice bounds of the event window in ``data``.

    Returns
    -------
    float
        Suggested y position for the annotation.
    """
    data = np.asarray(data)
    ymin, ymax = np.nanmin(data), np.nanmax(data)
    yevent = np.nanmean(data[start:stop])

    below = np.arange(ymin, yevent)
    above = np.arange(yevent, ymax)
    # deterministic tie-break (region above the event) keeps figures reproducible
    larger = below if len(below) > len(above) else above

    if larger.size == 0:
        return yevent
    return np.median(larger)


def clock_time(xlims, rec_datetime, times, axis):
    """Label the x-axis of ``axis`` with wall-clock times.

    Tick positions are interpreted as seconds after ``rec_datetime``. The tick
    spacing is chosen from the width of ``xlims``. For spacings of 10 minutes
    or more, the first tick is shifted onto a full multiple of the spacing on
    the wall clock (e.g. 22:30 rather than 22:25).

    Parameters
    ----------
    xlims : (float, float)
        x-axis limits in seconds after the recording start.
    rec_datetime : datetime.datetime
        Wall-clock time corresponding to x = 0.
    times : array-like
        Time axis in seconds; ticks are generated up to ``times[-1]``
        (exclusive).
    axis : matplotlib.axes.Axes
        Axes whose x ticks, tick labels and limits are set.
    """
    xlim = xlims
    dx = np.diff(xlim)[0]

    # tick spacing in seconds, chosen from the visible time span
    if dx <= 20:
        res = 1
    elif dx <= 120:
        res = 10
    elif dx <= 1200:
        res = 60
    elif dx <= 3600:
        res = 600  # 10 min
    elif dx <= 7200:
        res = 1800  # 30 min
    else:
        res = 3600  # 60 min

    # align minute-scale ticks to full multiples of the spacing on the clock
    label_idx0 = 0
    if dx > 1200:
        step_min = res // 60
        offset = rec_datetime.minute % step_min
        if offset != 0:
            label_idx0 = (step_min - offset) * 60

    xtick = np.arange(label_idx0, times[-1], res)
    # float(): datetime.timedelta rejects numpy integer scalars (np.int64)
    tick_datetimes = [
        rec_datetime + datetime.timedelta(seconds=float(t)) for t in xtick
    ]

    if dx > 120:
        fmt, rotation = "%H:%M", 0
    else:
        fmt, rotation = "%H:%M:%S", 45
    xlabels = np.array([d.strftime(fmt) for d in tick_datetimes])

    # only ticks strictly inside the limits are kept
    visible = (xtick > xlim[0]) & (xtick < xlim[1])
    axis.set_xticks(xtick[visible])
    axis.set_xticklabels(xlabels[visible], rotation=rotation)
    axis.set_xlim(xlim)


def log_transform_image(im):
    """Return log(image) scaled to the interval [0, 1].

    The smallest strictly positive value maps to 0 and the maximum maps to 1.
    Non-positive values are clipped to the smallest positive value and
    therefore also map to 0.

    ``im`` is returned unchanged in these cases:

    * it has no positive values;
    * its maximum does not exceed the smallest positive value;
    * it does not support boolean masking.

    Parameters
    ----------
    im : numpy.ndarray
        Image (or any array) to transform.

    Returns
    -------
    numpy.ndarray
        The scaled log image, or ``im`` itself (see above).
    """
    try:
        positive = im[im > 0]
    except (TypeError, IndexError):
        # not an array that supports boolean masking -> leave untouched
        return im
    if positive.size == 0:
        return im

    lo, hi = positive.min(), im.max()
    if hi > lo:
        return (np.log(im.clip(lo, hi)) - np.log(lo)) / (np.log(hi) - np.log(lo))
    return im


def kde2d(x, y, bandwidth, xbins=100j, ybins=100j, xlims=(0, 350), ylims=(0, 350), **kwargs):
    """Build a 2D kernel density estimate (KDE) evaluated on a regular grid.

    Parameters
    ----------
    x, y : array-like
        Sample coordinates (same length).
    bandwidth : float
        Kernel bandwidth passed to ``sklearn.neighbors.KernelDensity``.
    xbins, ybins : complex or float
        Grid resolution in ``np.mgrid`` notation. An imaginary number ``Nj``
        gives N points with both endpoints included; a real number is a step
        size. The default is 100 x 100 points.
    xlims, ylims : (float, float)
        Grid range along x and y. The default (0, 350) matches the electrode
        grid drawn in this notebook.
    **kwargs
        Forwarded to ``KernelDensity``.

    Returns
    -------
    xx, yy, zz : numpy.ndarray
        Grid coordinates (``xx`` varies along axis 0, ``yy`` along axis 1)
        and the estimated density at each grid point, all of shape
        ``xx.shape``.
    """
    # create grid of sample locations
    xx, yy = np.mgrid[xlims[0]:xlims[1]:xbins, ylims[0]:ylims[1]:ybins]

    # KernelDensity expects (n_samples, n_features); both training and
    # evaluation points use (y, x) column order so they line up
    xy_sample = np.vstack([yy.ravel(), xx.ravel()]).T
    xy_train = np.vstack([y, x]).T

    kde_skl = KernelDensity(bandwidth=bandwidth, **kwargs)
    kde_skl.fit(xy_train)

    # score_samples() returns the log-density; exponentiate to get the density
    z = np.exp(kde_skl.score_samples(xy_sample))
    return xx, yy, np.reshape(z, xx.shape)


In [3]:
s = PlotStyle()

# recording folder (hard-coded local path; adjust to your machine). Its
# basename encodes the recording start time parsed by get_datetime.
datapath = '/home/weygoldt/Data/uni/efish/output/2016-04-09-22_25/'

# data acquisition
rec_datetime = get_datetime(datapath)
grid = gt.GridTracks(datapath)
# os.path.join does not depend on datapath ending with a separator
df = pd.read_csv(os.path.join(datapath, "events.csv"))

[93m[1m[ GridTracks.__init__ ][0m No temperature an light data found in directory /home/weygoldt/Data/uni/efish/output/2016-04-09-22_25/.
[93m[1m[ GridTracks.__init__ ][0m No grid metadata found in directory /home/weygoldt/Data/uni/efish/output/2016-04-09-22_25/


In [4]:
# event indices of dataframe to iterate over
# (only the first entry is used below; nothing loops over `indices`)
indices = [2]

idx = indices[0]

# fish IDs of the interacting pair and the event's start/stop times,
# read from row `idx` of events.csv
ids = [df.id1[idx], df.id2[idx]]
tstart, tstop = df.start[idx], df.stop[idx]

# for each ID, extract all positions for the KDE
# (every sample whose identity in grid.ident_v equals the fish ID)
x_grid1 = grid.xpos_smth[grid.ident_v == ids[0]]
y_grid1 = grid.ypos_smth[grid.ident_v == ids[0]]
x_grid2 = grid.xpos_smth[grid.ident_v == ids[1]]
y_grid2 = grid.ypos_smth[grid.ident_v == ids[1]]

# compute kernel density estimates of all positions
# bandwidth 9 is in position units (presumably cm — confirm)
xx1, yy1, zz1 = kde2d(x_grid1, y_grid1, 9)
xx2, yy2, zz2 = kde2d(x_grid2, y_grid2, 9)

# Initialize a dyad class instance to get distance and delta f
dyad = gt.Dyad(grid, ids)
# indices into dyad.times closest to the event start/stop
# (presumably what gt.utils.find_closest returns — confirm)
start = gt.utils.find_closest(dyad.times, tstart)
stop = gt.utils.find_closest(dyad.times, tstop)

# get dyad track coordinates of both fish
# (presumably aligned with dyad.times — confirm)
x_dyad1 = dyad.xpos_smth_id1
y_dyad1 = dyad.ypos_smth_id1
x_dyad2 = dyad.xpos_smth_id2
y_dyad2 = dyad.ypos_smth_id2


[93m[1m[ Dyad.__init__  ][0m GridTracks instance has no temperature and light arrays.


In [5]:
""" The Gaussian 2d histogram 

fade1 = s.fade_cmap(cmocean.cm.deep)
fade2 = s.fade_cmap(cmocean.cm.deep)
backg = cmocean.cm.deep

# gaussian smooth 2d histogram of fish positions
sigma = 86
im1 = fs.gaussianhist2d(x_gridtracks1, y_gridtracks1, extent, sigma)
im2 = fs.gaussianhist2d(x_gridtracks2, y_gridtracks2, extent, sigma)

# plot into gridspec
vmin, vmax = fs.lims(im1, im2)  # get vmin and max
background = np.full(np.shape(im1), 0)  # make background image

ax_grid.imshow(
    im2,
    origin="lower",
    extent=extent.flatten(),
    cmap=fade2,
    alpha=0.6,
    vmin=vmin,
    vmax=vmax,
    zorder=2,
)
ax_grid.imshow(
    im1,
    origin="lower",
    extent=extent.flatten(),
    cmap=fade1,
    alpha=0.6,
    vmin=vmin,
    vmax=vmax,
    zorder=2,
)
"""

"""The gaussian 2d histogram for all positions
hist = gt.utils.gaussianhist2d(
    grid.xpos_smth, grid.ypos_smth, extent, sigma=32, bins=1000
)
ax_grid.imshow(hist, extent=extent.flatten(), cmap=cmocean.cm.haline, alpha=0.5)
"""

plt.rcParams["axes.titlepad"] = 12

# Grid setup
hr = [1, 1]
wr = [1, 1, 1, 0.6, 0.2]
ny = len(hr)
nx = len(wr)

fig = plt.figure(figsize=(300 * s.mm, 130 * s.mm), constrained_layout=False)

# init gridspec
gs = GridSpec(ny, nx, figure=fig, height_ratios=hr, width_ratios=wr)
gs.update(left=0.01, right=0.99, bottom=0.1, top=0.91, wspace=0.3, hspace=0.2)

# make axes
ax_grid = plt.subplot(gs[0:2, 0:2])
ax_grid.set_aspect("equal")
ax_dist = plt.subplot(gs[0, 2:])
ax_freq = plt.subplot(gs[1, 2:], sharex=ax_dist)
plt.setp(ax_dist.get_xticklabels(), visible=False)

# make axis titles
#s.fancy_title(ax_grid, "Two fish on electrode grid")
#s.fancy_title(ax_dist, "Fish distance and frequency over time")

ax_grid.set_title("Two fish on electrode grid", loc = "center")
ax_dist.set_title("Fish distance and frequency over time", loc = "center")

# make axis labels
ax_grid.set_xlabel("x position [cm]")
ax_grid.set_ylabel("y position [cm]")

# make axes limits
extent = np.array([[-10, 360], [-10, 360]])

ax_dist.set_ylabel("distance [cm]")
ax_dist.set(xticklabels=[])
ax_freq.set_xlabel("time [hh:mm]")
ax_freq.set_ylabel("frequency [Hz]")

# create grid coordinates
gridx = []
gridy = []

x_constructor = np.linspace(0, 350, 8)
for x_coord in x_constructor:
    y_constructor = np.ones(8) * x_coord
    gridx.extend(x_constructor)
    gridy.extend(y_constructor)

ax_grid.scatter(gridx, gridy, **s.grid_electrodes)
ax_grid.set_xlim(extent[0])
ax_grid.set_ylim(extent[1])

# log levels for contour plots
levels = np.geomspace(1 * 10 ** (-6), 1 * 10 ** (-2), 6)

# plot contour plots
ax_grid.contourf(
    xx1, yy1, zz1, levels=levels, extent=extent.flatten(), **s.kde1_shading1
)

ax_grid.contourf(
    xx2, yy2, zz2, levels=levels, extent=extent.flatten(), **s.kde2_shading1
)

ax_grid.contour(
    xx1, yy1, zz1, levels=levels, extent=extent.flatten(), **s.kde1_contours
)

ax_grid.contour(
    xx2, yy2, zz2, levels=levels, extent=extent.flatten(), **s.kde2_contours
)

# plot position during interaction
ax_grid.plot(
    x_dyad1[start:stop], y_dyad1[start:stop], **s.id1, label="position ID 1"
)

ax_grid.plot(
    x_dyad2[start:stop], y_dyad2[start:stop], **s.id2, label="position ID 2"
)

# plot relative distances
ax_dist.plot(dyad.times, dyad.dpos, **s.distance, label="fish distance")

ax_dist.axvspan(dyad.times[start], dyad.times[stop], **s.timewindow)

lims = 0, np.max(dyad.dpos) + 10
ax_dist.set_ylim(lims[0], lims[1])

# plot fundamental frequencies
ax_freq.plot(dyad.times, dyad.fund_id1, **s.id1, label="fund. freq. ID 1")
ax_freq.plot(dyad.times, dyad.fund_id2, **s.id2, label="fund. freq. ID 2")
ax_freq.axvspan(dyad.times[start], dyad.times[stop], **s.timewindow)

lims = gt.utils.lims(dyad.fund_id1, dyad.fund_id2)
ax_freq.set_ylim(lims[0] - 2, lims[1] + 2)
xlims = (np.min(dyad.times), np.max(dyad.times))
clock_time(xlims, rec_datetime, dyad.times, ax_freq)

# plot positions during a non-interaction ("control") window for comparison.
# CONTROL_WINDOW pins a hand-picked window (sample indices into the dyad
# arrays); set it to None to draw a randomly offset window instead.
CONTROL_WINDOW = (10356, 10765)

# label defaults when no control window is drawn: the event is annotated (1)
eventnum, randnum = 1, 2
plot_random = True  # control parameter to plot or not

if CONTROL_WINDOW is not None:
    # manual override chosen for a nice, non-overlapping position
    randstart, randstop = CONTROL_WINDOW
    eventnum, randnum = 2, 1
else:
    a, b = 5, 160
    # padding in samples: minutes * seconds * trackrate (assumes 3 Hz
    # tracking — confirm). Unseeded, so the window differs between runs.
    random_time_padding = random.randint(a, b) * 60 * 3
    if len(x_dyad2[:start]) > random_time_padding:
        # control window before the event -> numbered first
        randstart = start - random_time_padding
        randstop = stop - random_time_padding
        eventnum, randnum = 2, 1
    elif len(x_dyad2[stop:]) > random_time_padding:
        # control window after the event -> numbered second
        randstart = start + random_time_padding
        randstop = stop + random_time_padding
        eventnum, randnum = 1, 2
    else:
        plot_random = False

if plot_random:
    # plot both fish tracks during the control window
    ax_grid.plot(x_dyad1[randstart:randstop], y_dyad1[randstart:randstop], **s.id1)

    ax_grid.plot(x_dyad2[randstart:randstop], y_dyad2[randstart:randstop], **s.id2)

    # annotate each control-window track at the centre of its bounding box
    pos1 = data_center2d(x_dyad1, y_dyad1, randstart, randstop)
    pos2 = data_center2d(x_dyad2, y_dyad2, randstart, randstop)
    s.circled_annotation(randnum, ax_grid, pos1[0], pos1[1])
    s.circled_annotation(randnum, ax_grid, pos2[0], pos2[1])

    # shade the control window in the distance and frequency panels
    ax_dist.axvspan(dyad.times[randstart], dyad.times[randstop], **s.timewindow)
    ax_freq.axvspan(dyad.times[randstart], dyad.times[randstop], **s.timewindow)

    # label the control window in the distance panel, offset by dodge_data_y
    xpos = dyad.times[gt.utils.get_midpoint(randstart, randstop)]
    ypos = dodge_data_y(dyad.dpos, randstart, randstop)
    s.circled_annotation(randnum, ax_dist, xpos, ypos)

# annotate day and night: shade from the recording start up to sample
# index 32600 in gray in both time panels (presumably the light-regime
# switch — confirm). Clamped so shorter recordings do not raise IndexError.
night_end = min(32600, len(dyad.times) - 1)
ax_dist.axvspan(dyad.times[0], dyad.times[night_end], color = 'gray', lw = 0, alpha = 0.16)
ax_freq.axvspan(dyad.times[0], dyad.times[night_end], color = 'gray', lw = 0, alpha = 0.16)


# annotate interaction position in 2d kde plot, at the centre of the
# bounding box spanned by BOTH fish tracks during the event window
annotloc_x = np.append(x_dyad1[start:stop], x_dyad2[start:stop])
annotloc_y = np.append(y_dyad1[start:stop], y_dyad2[start:stop])
pos = data_center2d(annotloc_x, annotloc_y, 0, len(annotloc_x))
s.circled_annotation(eventnum, ax_grid, pos[0], pos[1])

# annotate interaction position in distance plot
xpos = dyad.times[gt.utils.get_midpoint(start, stop)]
ypos = dodge_data_y(dyad.dpos, start, stop)
s.circled_annotation(eventnum, ax_dist, xpos, ypos)

# annotate interaction duration
duration = dyad.times[stop] - dyad.times[start]
dur_str = gt.utils.strfdelta(duration, fmt="{M:02}m {S:02}s", inputtype="s")

props = dict(
    facecolor="white", edgecolor="none", boxstyle="round,pad=0.2", alpha=0.6
)
ax_grid.text(
    0.025,
    0.065,
    f"duration synchronized: {dur_str}",
    transform=ax_grid.transAxes,
    verticalalignment="top",
    bbox=props,
    zorder=1000,
)

# fixed y-limits for this figure; they override the data-driven limits
# set earlier in this cell
ax_dist.set_ylim(0,350)
ax_freq.set_ylim(860,900)

# NOTE(review): the spine and legend code below is disabled. The KDE legend
# patches reference `s.kde1_shading` / `s.kde2_shading`, while the active
# plotting code uses `s.kde1_shading1` / `s.kde2_shading1` — reconcile
# before re-enabling.

# ax_grid.spines['right'].set_visible(False)
# ax_grid.spines['top'].set_visible(False)
# 
# ax_freq.spines['right'].set_visible(False)
# ax_freq.spines['top'].set_visible(False)
# 
# ax_dist.spines['right'].set_visible(False)
# ax_dist.spines['top'].set_visible(False)

# # add legend to heatpmap
# handles, labels = ax_grid.get_legend_handles_labels()
# kde1_patch = mpatches.Patch(
#     label="position kde ID 1",
#     color=s.kde1_shading["colors"],
#     alpha=s.kde1_shading["alpha"] * 2,
# )

# kde2_patch = mpatches.Patch(
#     label="position kde ID 2",
#     color=s.kde2_shading["colors"],
#     alpha=s.kde2_shading["alpha"] * 2,
# )
# handles.extend([kde1_patch, kde2_patch])
# ax_grid.legend(
#     handles=handles, bbox_to_anchor=(0.5, -0.39), loc="lower center", ncol=2
# )

# add legend to position and fund plots
# eventnum_handle = AnyObject(f" ({eventnum}) ", "black")
# randnum_handle = AnyObject(f" ({randnum}) ", "black")
# 
# handles_freq, labels_freq = ax_freq.get_legend_handles_labels()
# handles_pos, labels_pos = ax_dist.get_legend_handles_labels()

# handles = [
#     handles_freq[0],
#     eventnum_handle,
#     handles_freq[1],
#     randnum_handle,
#     handles_pos[0],
# ]

# labels = [
#     labels_freq[0],
#     "sync. modulation",
#     labels_freq[1],
#     "random time",
#     labels_pos[0],
# ]

# ax_freq.legend(
#     handles=handles,
#     labels=labels,
#     handler_map={
#         eventnum_handle: AnyObjectHandler(),
#         randnum_handle: AnyObjectHandler(),
#     },
#     bbox_to_anchor=(0.5, -0.8),
#     loc="lower center",
#     ncol=3,
# )

fig.align_labels()

# os.path.join works whether or not grid.datapath ends with a separator
fig.savefig(os.path.join(grid.datapath, f"eventposition_{idx}.pdf"), bbox_inches="tight")
plt.show()

