In [None]:
import pandas as pd
import numpy as np
import holoviews as hv
from holoviews import opts
import hvplot.pandas
import bokeh
import matplotlib.pyplot as plt
hv.extension('bokeh')
import panel as pn
pn.extension()
from datetime import datetime, timezone
import param

In [None]:
# Bounding box (degrees longitude / latitude) used later to select samples inside the eddy.
eddy_min_longitude = -16
eddy_max_longitude = -13
eddy_min_latitude = 48
eddy_max_latitude = 50

# ---- Load raw tables -------------------------------------------------------
observations = pd.read_csv('../data/observations.csv')
metadata = pd.read_csv('../data/metadata.csv')
metadata['sample_time'] = metadata.sample_time.apply(datetime.fromisoformat)
topic_probs = pd.read_csv('../data/topic_probs.csv', dtype=float)
word_probs = pd.read_csv('../data/word_probs.csv')
word_topic_matrix = pd.read_csv('../data/word_topic_matrix.csv')
oceanphysics = pd.read_csv('../data/oceanphysics.csv')
oceanphysics.columns = ['sample_time', 'salinity', 'temperature', 'fluorescence', 'conductivity', 'sigma_t', 'instrument_date']
oceanphysics = oceanphysics.drop('instrument_date', axis=1)
eddycenter = pd.read_csv('../data/eddycenter.csv')

# These two files' timestamps are parsed without a zone and then tagged as UTC.
oceanphysics['sample_time'] = oceanphysics.sample_time.apply(
    lambda x: datetime.strptime(x, '%d-%m-%Y %H:%M:%S').replace(tzinfo=timezone.utc)
)
eddycenter['sample_time'] = eddycenter.sample_time.apply(
    lambda x: datetime.strptime(x, '%d-%b-%Y %H:%M:%S').replace(tzinfo=timezone.utc)
)

# ---- Align everything on metadata's sample times ---------------------------
metadata = metadata.set_index('sample_time')
# Nearest-in-time physics record for each sample (method='nearest' requires a sorted source index).
oceanphysics = oceanphysics.set_index('sample_time').reindex(metadata.index, method='nearest')

# topic_probs, word_probs and observations are given metadata's timestamps by
# position, which assumes their rows are in metadata's row order.
# NOTE(review): only the row count is checked here, not the order.
assert len(topic_probs) == len(word_probs) == len(observations) == len(metadata), (
    'topic_probs, word_probs and observations must have one row per metadata row'
)
topic_probs['sample_time'] = metadata.index
word_probs['sample_time'] = metadata.index
observations['sample_time'] = metadata.index
topic_probs = topic_probs.set_index('sample_time')
word_probs = word_probs.set_index('sample_time')
observations = observations.set_index('sample_time')
metadata_oceanphysics = pd.concat([metadata, oceanphysics], axis=1)
# Relative abundance per sample: divide each row by its row total
# (the row-sum Series is aligned on the index, i.e. axis='index').
observations_norm = observations.div(observations.sum(axis=1), axis='index')
# Position and most likely topic (column with the highest probability) per sample.
topic_probs_idxmax = pd.DataFrame({
    'latitude': metadata.latitude,
    'longitude': metadata.longitude,
    'ml_topic': topic_probs.idxmax(axis=1)
})
# Nearest-in-time eddy-centre record for each sample.
eddycenter = eddycenter.set_index('sample_time').reindex(metadata.index, method='nearest')


In [None]:
# Number of samples.
N = len(observations)

# Samples whose position lies strictly inside the eddy bounding box.
eddy_idx = (
    (metadata.longitude > eddy_min_longitude)
    & (metadata.longitude < eddy_max_longitude)
    & (metadata.latitude > eddy_min_latitude)
    & (metadata.latitude < eddy_max_latitude)
)

# NOTE(review): not used in the visible notebook; '%z' already emits the sign.
timefmt = '%Y-%m-%d %H:%M:%S%z'
epoch1_start = datetime.fromisoformat('2021-05-04 00:00:00+00:00')
epoch2_start = datetime.fromisoformat('2021-05-11 00:00:00+00:00')
epoch3_start = datetime.fromisoformat('2021-05-21 00:00:00+00:00')
# Epoch masks over metadata's rows. Everything before epoch2_start counts as
# epoch 1; a sample exactly on a boundary belongs to the later epoch.
epoch1_idx = metadata.index < epoch2_start
epoch2_idx = ~epoch1_idx & (metadata.index < epoch3_start)
epoch3_idx = ~epoch1_idx & ~epoch2_idx

# Label each observation with its epoch (1, 2 or 3) from the same masks. The
# previous open-interval comparisons left samples exactly at epoch2_start /
# epoch3_start labelled 0. `assign` returns a new frame, so `observations`
# itself does not gain an "epoch" column.
observations_epoch = observations.assign(
    epoch=1 * epoch1_idx + 2 * epoch2_idx + 3 * epoch3_idx
)

In [None]:
hv.extension('matplotlib')


def _plot_on_right_axis(plot, column):
    """Overlay ``oceanphysics[column]`` against time on a twin (right-hand) y-axis of ``plot``."""
    right_axis = plot.handles["axis"].twinx()
    right_axis.plot(oceanphysics.index, oceanphysics[column], color='black')
    right_axis.set_ylabel(column)


def right_axis_temp_hook(plot, element):
    """HoloViews (matplotlib) hook: draw temperature on a right-hand axis."""
    _plot_on_right_axis(plot, 'temperature')


def right_axis_salt_hook(plot, element):
    """HoloViews (matplotlib) hook: draw salinity on a right-hand axis."""
    _plot_on_right_axis(plot, 'salinity')


def epoch_hook(plot, element):
    """HoloViews (matplotlib) hook: mark the start of epoch 2 with a dashed vertical line."""
    # Draw on this plot's own axes. Going through pyplot's global "current
    # axes" (plt.axvline) can target another figure, and forcing figure=fig
    # on a line added to a different figure's axes raises.
    plot.handles["axis"].axvline(epoch2_start, ls='--', c='black')


def _stacked_area_overlay(frame):
    """Build an Overlay with one Area per column of ``frame`` (index on x).

    Each Area is labelled with its column name so the legend identifies it.
    """
    areas = [
        hv.Area(list(zip(frame.index, frame[column])), label=str(column)).opts(
            linewidth=0,
            color=hv.Cycle("tab20"),
            show_legend=True,
            hooks=[epoch_hook],
        )
        for column in frame.columns
    ]
    return hv.Overlay(areas).opts(hooks=[epoch_hook])


word_prob_overlay = _stacked_area_overlay(observations_norm)
topic_prob_overlay = _stacked_area_overlay(topic_probs)

In [None]:
hv.extension('matplotlib')

# (1) Stacked per-taxon proportions over time, with the epoch marker.
hv.Area.stack(word_prob_overlay).opts(
    title="(1) Observations vs time",
    xlabel="sample_time",
    ylabel="taxon_proportion",
    fig_inches=10,
    show_legend=True,
    legend_position="bottom",
    hooks=[epoch_hook],
)

In [None]:
hv.extension('matplotlib')

# (2) Stacked per-topic proportions over time, with the epoch marker.
hv.Area.stack(topic_prob_overlay).opts(
    title="(2) Topics vs time",
    xlabel="sample_time",
    ylabel="topic_proportion",
    fig_inches=10,
    show_legend=True,
    legend_position="bottom",
    hooks=[epoch_hook],
)

In [None]:
hv.extension('matplotlib')

# (3) Stacked taxon proportions with temperature on a right-hand axis.
hv.Area.stack(word_prob_overlay).opts(
    title="(3) Observations vs time w/ temperature",
    xlabel="sample_time",
    ylabel="taxon_proportion",
    fig_inches=10,
    show_legend=True,
    legend_position="bottom",
    hooks=[right_axis_temp_hook, epoch_hook],
)

In [None]:
hv.extension('matplotlib')

# (4) Stacked topic proportions with temperature on a right-hand axis.
hv.Area.stack(topic_prob_overlay).opts(
    title="(4) Topics vs time w/ temperature",
    xlabel="sample_time",
    ylabel="topic_proportion",
    fig_inches=10,
    show_legend=True,
    legend_position="bottom",
    hooks=[right_axis_temp_hook, epoch_hook],
)

In [None]:
hv.extension('matplotlib')

# (5) Stacked taxon proportions with salinity on a right-hand axis.
hv.Area.stack(word_prob_overlay).opts(
    title="(5) Observations vs time w/ salinity",
    xlabel="sample_time",
    ylabel="taxon_proportion",
    fig_inches=10,
    show_legend=True,
    legend_position="bottom",
    hooks=[right_axis_salt_hook, epoch_hook],
)

In [None]:
hv.extension('matplotlib')

# (6) Stacked topic proportions with salinity on a right-hand axis.
hv.Area.stack(topic_prob_overlay).opts(
    title="(6) Topics vs time w/ salinity",
    xlabel="sample_time",
    ylabel="topic_proportion",
    fig_inches=10,
    show_legend=True,
    legend_position="bottom",
    hooks=[right_axis_salt_hook, epoch_hook],
)

In [None]:
hv.extension('bokeh')

# Within the eddy, rank taxa by total count per epoch
# (rank 1 = most abundant; tied taxa share the lowest rank).
ranks = (
    observations_epoch[eddy_idx]
    .groupby("epoch")
    .aggregate('sum')
    .rank(axis=1, method='min', ascending=False)
    .reset_index()
    .melt(id_vars=['epoch'], var_name='taxon', value_name="rank")
)

# Order the taxon axis of both charts by epoch-1 rank so bar positions are
# comparable. A stable sort keeps the tie order deterministic.
epoch1_ranks = ranks[ranks.epoch == 1].sort_values('rank', kind='stable').taxon.to_list()
epoch1_rank_dim = hv.Dimension('taxon', values=epoch1_ranks)

# Shared bar styling. Alpha keeps the epoch-1 bars visible beneath the
# overlaid epoch-2 bars.
bar_opts = dict(
    width=2000,
    height=500,
    xrotation=90,
    show_grid=True,
    show_legend=True,
    alpha=0.6,
)
# Labels give the legend one entry per epoch.
epoch1_bars = hv.Bars(ranks[ranks.epoch == 1], epoch1_rank_dim, 'rank', label='epoch 1').opts(**bar_opts)
epoch2_bars = hv.Bars(ranks[ranks.epoch == 2], epoch1_rank_dim, 'rank', label='epoch 2').opts(**bar_opts)

(epoch1_bars * epoch2_bars).opts(
    show_legend=True,
    show_grid=True,
    multiple_legends=True,
    title="(7) Rank abundance, epoch 1 and 2")

In [None]:
hv.extension('bokeh')
# (8) In-eddy samples from epoch 1, coloured by their most likely topic.
topic_probs_idxmax[epoch1_idx & eddy_idx].hvplot.scatter(
    x='longitude',
    y='latitude',
    by='ml_topic',
    title='(8) Epoch 1 max likelihood topic',
)

In [None]:
hv.extension('bokeh')
# (9) In-eddy samples from epoch 2, coloured by their most likely topic.
topic_probs_idxmax[epoch2_idx & eddy_idx].hvplot.scatter(
    x='longitude',
    y='latitude',
    by='ml_topic',
    title='(9) Epoch 2 max likelihood topic',
)



In [None]:
hv.extension('bokeh')
# (10) Heatmap of the word-topic matrix on a log colour scale.
word_topic_matrix.hvplot(
    kind='heatmap',
    title='(10) Word-topic matrix',
    rot=90,
    cmap='BuGn',
    width=2000,
    height=600,
    logz=True,
    clim=(1e-6, 1),
    # One y tick per matrix row (previously hard-coded to 8 rows).
    yticks=list(range(len(word_topic_matrix))),
)

In [None]:
hv.extension('bokeh')

# Long-format table: one row per (sample, topic) holding the topic fraction,
# the sample's r_ec (presumably distance to the eddy centre), hour of day and
# 0/1 epoch flags. `assign` returns a new frame, so topic_probs itself gains no
# helper columns (pd.DataFrame(topic_probs) may share its column storage).
topics_eddy = (
    topic_probs.assign(
        r_ec=eddycenter['r_ec'],
        time_of_day=[t.hour for t in topic_probs.index],
        epoch1=epoch1_idx * 1,
        epoch2=epoch2_idx * 1,
    )
    .reset_index()
    .melt(
        id_vars=['sample_time', 'r_ec', 'time_of_day', 'epoch1', 'epoch2'],
        var_name='topic',
        value_name='fraction',
    )
)
# Topic labels exactly as they appear in the melted 'topic' column.
all_topics = [str(column) for column in topic_probs.columns]


class TopicEddyPlot(param.Parameterized):
    """Interactive scatter of topic fraction vs r_ec, filterable by topic, hour of day and epoch."""

    topics = param.ListSelector(default=all_topics, objects=all_topics)
    time_of_day = param.ListSelector(default=list(range(24)), objects=list(range(24)))
    # True: keep unselected points but fade them (alpha 0.25); False: hide them.
    mute = param.Boolean(default=True)
    epochs = param.ListSelector(default=[1, 2], objects=[1, 2])

    @param.depends('topics', 'time_of_day', 'mute', 'epochs')
    def plot(self):
        """Return the scatter for the current selection."""
        # A row is in an selected epoch if any chosen epoch flag is set.
        # With no epochs chosen this frame has no columns and .any() is False.
        epoch_flags = topics_eddy[[f'epoch{epoch}' for epoch in self.epochs]]
        sel = (
            topics_eddy.topic.isin(self.topics)
            & topics_eddy.time_of_day.isin(self.time_of_day)
            & epoch_flags.astype(bool).any(axis=1)
        )
        # Per-call frame carrying the alpha column; the shared topics_eddy is not mutated.
        data = topics_eddy.assign(alpha=0.25 + 0.75 * sel)
        if not self.mute:
            data = data[sel]
        scatter = hv.Scatter(
            data, kdims=['r_ec'], vdims=['fraction', 'topic', 'sample_time', 'time_of_day', 'alpha']
        ).opts(
            color='topic',
            width=750,
            height=600,
            cmap='tab10',
            logx=True,
            logy=True,
            xlim=(1.0e-1, 1.0e3),
            ylim=(1.0e-4, 1.0e0)
        )
        if self.mute:
            scatter = scatter.opts(alpha='alpha')
        return scatter


tep = TopicEddyPlot()
# Widget captions are set with `name` (Panel widgets have no `title` parameter).
controls = pn.Row(
    pn.Column(
        pn.widgets.CheckBoxGroup.from_param(tep.param['epochs'], name='epochs', width=75, height=75),
        pn.widgets.CheckBoxGroup.from_param(tep.param['topics'], name='Topic', width=75, height=250),
        pn.widgets.Toggle.from_param(tep.param['mute'], name='mute', width=75, height=25)
    ),
    pn.widgets.MultiSelect.from_param(tep.param['time_of_day'], name='Hour', width=75, height=425),
)
pn.Row(controls, tep.plot)

In [None]:
hv.extension('bokeh')

# Rank of each taxon within each row of the word-topic matrix (rank 1 = smallest
# value in that row). method='min' gives tied values one shared integer rank;
# the default 'average' can produce x.5 ranks that astype(int) would silently truncate.
hv.Bars(
    word_topic_matrix.rank(axis=1, method='min').astype(int).reset_index().melt(
        var_name='taxon', value_name='rank', id_vars=['index']
    ),
    kdims=['taxon'],
    vdims=['rank', 'index'],
).opts(
    title='(11) Word-topic ranks',
    xrotation=90,
    width=1000,
    height=800,
    stacked=False,
    color='index',
)

In [None]:
# Displays HoloViews' built-in help for the Bars element.
# NOTE(review): looks like a leftover exploration call; consider removing it before sharing the notebook.
hv.help(hv.Bars)