# Event view of camera pictures and PMT rates

This notebook builds an event view that shows a camera picture next to the PMT-Spectrometer rates recorded around the same time.

In [None]:
%matplotlib notebook
# NOTE(review): the `notebook` backend only works in the classic Notebook (<7);
# in JupyterLab / Notebook 7 use `%matplotlib widget` (ipympl) instead.

# Imports for the event view: strawb (ONC file DB, sensors, tools), numpy/pandas and matplotlib.
# NOTE(review): glob, datetime, plotly (incl. go/px) and random look unused in this notebook.
import glob
import os

import pandas as pd
import numpy as np
import datetime
import matplotlib.pyplot as plt
import matplotlib.dates as mdates

import strawb
import strawb.sensors.pmtspec
import strawb.tools

import plotly
import plotly.graph_objs as go
import plotly.express as px
import plotly.io as pio
pio.templates.default = "plotly_white"

import random

plt.rcParams.update({"text.usetex": False})  # fix time fmt of x_ticks

In [None]:
# Select the time of interest: pick one of the candidate events below by its key.
event_candidates = {
    'wide peak, no image': '2021-09-04T20:42:58',
    'blinking': '2021-09-04T23:44:09',
    "over-saturate 'no filter' channel": '2021-09-24T21:57',
}
timestamp = np.datetime64(event_candidates['wide peak, no image'])

# Prepare the files

In [None]:
# Load the local file DB, update it from ONC and save it (save_db=True).
# To (re)build the whole DB instead, run `db.load_entire_db_from_ONC()`.
db = strawb.SyncDBHandler()
db.load_onc_db_update(save_db=True, output=True)

In [None]:
# Select the PMT-Spectrometer module files whose `dateFrom` lies in the full hour containing `timestamp`.
hour_start = pd.Timestamp(np.datetime64(timestamp, "h"), tz="UTC")
hour_end = hour_start + np.timedelta64(1, "h")

mask = db.dataframe.deviceCode == 'TUMPMTSPECTROMETER001'
mask &= (db.dataframe.dateFrom >= hour_start) & (db.dataframe.dateFrom < hour_end)

# further optional filters:
# mask &= db.dataframe.dataProductCode == 'PMTSD'  # a single data product
# mask &= db.dataframe.synced                      # synced files only
# mask &= db.dataframe.file_version > 0            # by 'measurement_type'

# show the selected files (the cells below pick the PMTSD and the MSSCD entry)
db.dataframe[mask]

## Download the selected files that are not synced yet

In [None]:
# Download the selected files if any of them is not synced yet, then save the DB.
all_synced = db.dataframe.synced[mask].all()
if not all_synced:
    db.update_db_and_load_files(
        db.dataframe[mask],
        download=True,  # download the files
        output=True,    # print output to console
    )
    db.save_db()

## Load the PMTSpec and camera files

In [None]:
# select the PMTSpec file(s) -> dataProductCode == 'PMTSD'
item = db.dataframe[mask & (db.dataframe.dataProductCode == 'PMTSD')]
# fail early with a readable message instead of a KeyError/IndexError below
if item.empty:
    raise ValueError(f'No PMTSD file selected for the hour around {timestamp}; check `mask` and the DB sync.')

# `.iloc[0]` takes the first selected row. `item.fullPath[0]` looks up the index *label* 0
# (for an integer index), which is usually gone after boolean filtering.
pmtspec = strawb.sensors.PMTSpec(file=item.fullPath.iloc[0])
pmtspec.file_handler.file_attributes  # show the file attributes, time in seconds since epoch (1.1.1970) UTC

In [None]:
# select the Camera file(s) -> dataProductCode == 'MSSCD'
item = db.dataframe[mask & (db.dataframe.dataProductCode == 'MSSCD')]

camera = strawb.sensors.Camera(file=item.fullPath[0])
camera.file_handler.file_attributes  # show the file attributes, time in seconds since epoch (1.1.1970) UTC

# Prepare colors

In [None]:
# prepare colors
### use a nice colormap, or use wavelength_to_rgb for conversion
# NOTE(review): the meaning of the three positional arguments is defined by strawb's
# `add_colors` (presumably one colormap per channel group, None -> default) — confirm.
# `plt.get_cmap` replaces `plt.cm.get_cmap` (matplotlib.cm.get_cmap), which was deprecated
# in Matplotlib 3.7 and removed in 3.9.
pmtspec.pmt_meta_data.add_colors(
    plt.get_cmap('viridis'),
    None,
    plt.get_cmap('gist_heat'),  # reddish colors
)
pmtspec.pmt_meta_data.channel_meta_array

# Class that collects the data for the plot

In [None]:
class PlotData:
    """Collect everything the event-view plot needs for one camera frame.

    It holds the selected camera frame (dark-subtracted RGB image) and the camera
    exposure outline for the timeline. It also holds the PMTSpec rates: raw and
    interpolated rates with masks for a zoom window around the frame, plus a second,
    slower interpolation for the timeline of the whole file.

    Side effect: the constructor sets ``pmtspec.trb_rates.interp_frequency``. After it
    returns, the value is ``interp_frequency_slow``.

    PARAMETER
    ---------
    pmtspec: strawb.sensors.PMTSpec
    camera: strawb.sensors.Camera
    index: int, optional
        index of the camera frame to show
    time_offset: np.timedelta64, optional
        start of the zoom window relative to the frame timestamp. The window then
        spans the exposure time of that frame. Default 2 s (the former hard-coded value).
    interp_frequency: int or float, optional
        interpolation frequency used for the zoomed rates (default 5)
    interp_frequency_slow: int or float, optional
        interpolation frequency used for the timeline rates (default 2)
    """

    def __init__(self, pmtspec, camera, index=27,
                 time_offset=np.timedelta64(2, "s"),
                 interp_frequency=5,
                 interp_frequency_slow=2,
                 ):
        self.pmtspec = pmtspec
        self.camera = camera

        # camera: selected frame, its timestamp and the exposure outline for the timeline
        self.index = index
        self.timestamp = camera.file_handler.time.asdatetime()[index]
        # NOTE(review): load_rgb(None, ...) appears to load all frames; only one is kept
        self.images_rgb = camera.images.load_rgb(None, subtract_dark=True)[index]
        self.cam_time, self.cam_capture = self.fill_where_exposed(camera, index)

        # zoom window: [timestamp + time_offset, timestamp + time_offset + exposure]
        # exposure_time is in seconds (it is added to the epoch seconds in fill_where_exposed)
        exposure_time = np.timedelta64(int(camera.file_handler.exposure_time[index] * 1e3), "ms")
        start_time = self.timestamp + time_offset
        end_time = start_time + exposure_time

        # raw PMTSpec data: time, mask for the zoom window, rate
        self.raw_time = pmtspec.trb_rates.time_middle
        self.raw_mask = (self.raw_time >= start_time) & (self.raw_time <= end_time)
        self.raw_rate = pmtspec.trb_rates.rate

        # interpolated data for the zoom panel: time, mask for the zoom window, rate
        pmtspec.trb_rates.interp_frequency = interp_frequency
        self.interp_time = pmtspec.trb_rates.interp_time
        self.interp_mask = (self.interp_time >= start_time) & (self.interp_time <= end_time)
        self.interp_rate = pmtspec.trb_rates.interp_rate

        # slower interpolation for the timeline of the whole file
        pmtspec.trb_rates.interp_frequency = interp_frequency_slow
        self.interp_time_slow = pmtspec.trb_rates.interp_time
        self.interp_rate_slow = pmtspec.trb_rates.interp_rate

    # Camera fill where exposed
    @staticmethod
    def fill_where_exposed(camera, index=None):
        """Create arrays to fill the areas where the camera exposed images.

        Usage:
        plt.fill_between(strawb.tools.asdatetime(time), 0, 1,
                         where=capture>0,  # or >1 for the selected image
                         color='gray',
                         alpha=0.2,
                         label='exposure images',
                         transform=plt.gca().get_xaxis_transform())

        PARAMETER
        ---------
        camera: strawb.sensors.Camera
        index: int, optional
            to mark a selected frame (value 2 instead of 1 in `capture`)
        RETURN
        ------
        time: ndarray, float
            time since epoch in seconds, 4 points per frame: [start, start, stop, stop]
        capture: ndarray, int
            per point 1 if an image was exposed at that time (2 for the selected
            frame), 0 otherwise
        """
        start_time = np.asarray(camera.file_handler.time[:])
        stop_time = start_time + np.asarray(camera.file_handler.exposure_time[:])

        # points for [dis-, en-, en-, dis-]-able -> one rectangular step per frame
        time = np.array((start_time, start_time, stop_time, stop_time)).T
        capture = np.zeros_like(time, dtype=int)
        capture[:, 1:3] = 1

        if index is not None:
            capture[index, 1:3] = 2  # the selected frame

        return time.flatten(), capture.flatten()

# Get the index of the camera frame closest to the timestamp

In [None]:
# Index of the camera frame closest in time to `timestamp`.
frame_times = camera.file_handler.time.asdatetime()
index = np.argmin(np.abs(frame_times[:] - timestamp))

# The timeline plot prints the index of every frame; to show another frame,
# set it here and re-run the cells below, e.g.
# index = 25
print(f'Take picture with index: {index} at {frame_times[index]}')
plot_data = PlotData(pmtspec=pmtspec, camera=camera, index=index)

# Plot

In [None]:
# --- options -----------------------------------------------------------------
with_timeline = True   # add a bottom panel with the rates and camera exposures of the whole file
log_scale = True       # log y-axis for the zoomed rates; False -> linear axis starting at 0
show_raw_data = False  # overlay the raw (not interpolated) rates as points in the zoom panel

if with_timeline:
    figsize = np.array([14., 10]) / 2.54  # cm -> inch
    nrows = 2
    height_ratios = [2, 1]
else:
    figsize = np.array([14., 6.6]) / 2.54  # cm -> inch
    nrows = 1
    height_ratios = [2]

# create the figure: image (top left), zoomed rates (top right), timeline (bottom, optional)
fig = plt.figure(figsize=figsize)
gs = plt.GridSpec(nrows, 2, width_ratios=[1, 2.4], height_ratios=height_ratios)

ax = [plt.subplot(gs[0, 0]), plt.subplot(gs[0, 1])]
if with_timeline:
    ax_row_1 = plt.subplot(gs[1, :])
    ax.append(ax_row_1)

### The image
ax[0].set_xlabel('Image of Event')
ax[0].xaxis.set_ticks([])
ax[0].yaxis.set_ticks([])
# plot_data.images_rgb is already the selected frame
# NOTE(review): the 2**16 scaling assumes 16-bit pixel values — confirm
ax[0].imshow(plot_data.images_rgb / 2 ** 16)

### The PMTs
# Select the channels to plot by wavelength. A separate name is used because `mask`
# holds the DB file selection from the cells above; overwriting it would break re-running them.
channel_mask = np.zeros_like(pmtspec.pmt_meta_data.channel_meta_array['wavelength'], dtype=bool)
for wavelength in [0, 400, 450, 480, 492, 510, 550]:
    channel_mask |= pmtspec.pmt_meta_data.channel_meta_array['wavelength'] == wavelength

for ch in pmtspec.pmt_meta_data.channel_meta_array[channel_mask]:
    ax[1].plot(
        plot_data.interp_time[plot_data.interp_mask],
        plot_data.interp_rate[ch["index"], plot_data.interp_mask],
        color=ch["color"],
        label=ch["label"],
    )
    if show_raw_data:
        ax[1].plot(
            plot_data.raw_time[plot_data.raw_mask],
            plot_data.raw_rate[ch["index"], plot_data.raw_mask],
            color=ch["color"],
            marker="o", lw=0, alpha=0.01, ms=.5,
        )
    if with_timeline:
        ax[2].plot(
            plot_data.interp_time_slow,
            plot_data.interp_rate_slow[ch["index"]],
            color=ch["color"],
        )

### Camera exposures and frame indices on the timeline
if with_timeline:
    # x in data coordinates, y in axes coordinates (0 = bottom, 1 = top of the panel)
    x_data_y_axes = ax[2].get_xaxis_transform()
    cam_time = strawb.tools.asdatetime(plot_data.cam_time)
    ax[2].fill_between(cam_time, 0, 1,
                       where=plot_data.cam_capture > 0,
                       color='gray', alpha=0.2, label='exposure images',
                       transform=x_data_y_axes)
    ax[2].fill_between(cam_time, 0, 1,
                       where=plot_data.cam_capture > 1,
                       color='steelblue', alpha=0.5, label='selected image',
                       transform=x_data_y_axes)

    # Print each frame index at the middle of its exposure, so another frame can be
    # chosen in the index cell; cam_time holds [start, start, stop, stop] per frame.
    frame_middle = strawb.tools.asdatetime((plot_data.cam_time[::4] + plot_data.cam_time[2::4]) * .5)
    for i, t_i in enumerate(frame_middle):
        ax[2].text(t_i, .5, str(i), fontsize=10, horizontalalignment='center',
                   transform=x_data_y_axes)

    ax[2].set_ylabel('PMT Rates [Hz]')
    ax[2].legend(loc='upper right')
    ax[2].set_xlim((plot_data.interp_time_slow[0], plot_data.interp_time_slow[-1]))
    ax[2].grid()
    ax[2].xaxis.set_major_formatter(mdates.ConciseDateFormatter(ax[2].xaxis.get_major_locator()))
    ax[2].set_yscale('log')
    ax[2].set_ylim(1e2)

### Zoom panel layout
ax[1].legend(loc='best', ncol=1, title='Color Channel')
ax[1].set_xlabel('Time')  # date-formatted axis, not seconds
ax[1].set_ylabel('PMT Rates [Hz]')
zoom_time = plot_data.interp_time[plot_data.interp_mask]
if len(zoom_time) > 0:  # an empty zoom window would raise an IndexError here
    ax[1].set_xlim((zoom_time[0], zoom_time[-1]))
ax[1].xaxis.set_major_formatter(mdates.ConciseDateFormatter(ax[1].xaxis.get_major_locator()))
ax[1].grid()

if log_scale:
    ax[1].set_yscale('log')
    ax[1].set_ylim(1e2)
else:
    # Axes has no `ylim` method; the former `ax[1].ylim(...)` raised AttributeError
    ax[1].set_ylim(0, plot_data.raw_rate.T[plot_data.raw_mask].max() * 1.05)

fig.tight_layout()

In [None]:
# Save the figure as PDF. The ':' of the ISO timestamp is replaced because it is not
# allowed in Windows file names. Change `output_dir` to save somewhere else.
output_dir = os.path.expanduser('~/Downloads')
os.makedirs(output_dir, exist_ok=True)  # the folder may not exist, e.g. on a server

fig.tight_layout()
fig.savefig(os.path.join(output_dir, f"pmt_event-{str(timestamp).replace(':', '-')}.pdf"), format='pdf')