# This notebook creates an event view that combines a camera picture with the PMT rates (and the module power consumption) around a selected timestamp.

In [None]:
%matplotlib notebook

# This example shows how to download files from the ONC server
import datetime
import glob
import os
import random

import matplotlib.dates as mdates
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly
import plotly.express as px
import plotly.graph_objs as go
import plotly.io as pio
import scipy.optimize

import strawb
import strawb.sensors.pmtspec
import strawb.tools

pio.templates.default = "plotly_white"

plt.rcParams.update({"text.usetex": False})  # fix time fmt of x_ticks

# Get DB and update

In [None]:
# Load the local file database and fetch the latest entries from ONC.
# For a fresh setup, run `db.load_entire_db_from_ONC()` once to load the entire DB instead.
db = strawb.SyncDBHandler()
db.load_onc_db_update(save_db=True, output=True)

# Select 'event' by timestamp

In [None]:
# Select the time of interest. Only one assignment is active; the other known
# events are kept as comments for quick switching.
# timestamp = np.datetime64('2021-09-04T20:42:58')  # wide peak, no image
# timestamp = np.datetime64('2021-09-04T23:44:09')  # blinking
timestamp = np.datetime64('2021-09-24T21:57')  # over-saturate 'no filter' channel

# Select the files around the event

In [None]:
# select the files of the PMT spectrometer module
mask = db.dataframe.deviceCode == 'TUMPMTSPECTROMETER001'
# optional filters:
# mask &= db.dataframe.dataProductCode == 'PMTSD'  # a single data product only
# mask &= db.dataframe.synced  # synced files only
# mask &= db.dataframe.file_version > 0  # by 'measurement_type'

## select the full hour(s) around one bioluminescence event: +-5 min around `timestamp`,
## both ends are truncated to the full hour by the cast to "h"
time_from = np.datetime64(timestamp - np.timedelta64(5, "m"), "h")
time_to = np.datetime64(timestamp + np.timedelta64(5, "m"), "h")
ts_from = pd.Timestamp(time_from, tz="UTC")
ts_to = pd.Timestamp(time_to, tz="UTC")

date_from = db.dataframe.dateFrom
date_to = db.dataframe.dateTo
# files which cover time_from
mask_from = (date_from <= ts_from) & (date_to >= ts_from)
# files which cover time_to
mask_to = (date_from <= ts_to) & (date_to >= ts_to)
# files which lie completely within [time_from, time_to]
mask_within = (date_from >= ts_from) & (date_to <= ts_to)

mask &= mask_from | mask_to | mask_within

### show the selected files from the DB
db.dataframe[mask]

## Download any selected files that are not synced yet

In [None]:
### download the selected files, unless all of them are synced already
selected_files = db.dataframe[mask]
if not selected_files.synced.all():
    db.update_db_and_load_files(
        selected_files,
        download=True,  # download the files
        output=True,  # print output to console
    )
    # the DB is saved explicitly once the download finished
    db.save_db()

## Load the PMTSpec, Camera, and Module files

In [None]:
# select the PMTSpec file(s) -> dataProductCode == 'PMTSD'
item = db.dataframe[mask & (db.dataframe.dataProductCode == 'PMTSD')]
if item.empty:
    raise FileNotFoundError('No PMTSpec file (PMTSD) found for the selected time range.')

# close the file of a previous run of this cell, if any (best effort)
if 'pmtspec' in globals():
    try:
        pmtspec.file_handler.close()
    except Exception as err:
        print(f'Could not close the previous PMTSpec file: {err}')

# generate a virtual hdf5 to combine the datasets if there are multiple files selected
if len(item) > 1:
    vhdf5 = strawb.VirtualHDF5('PMTSD_event_view.hdf5', item.fullPath.to_list())
    file_name = vhdf5.file_name
else:
    # positional access: after masking, the index labels are those of db.dataframe
    file_name = item.fullPath.iloc[0]

pmtspec = strawb.sensors.PMTSpec(file=file_name)
pmtspec.file_handler.file_attributes  # show the file attributes, time in seconds since epoch (1.1.1970) UTC

In [None]:
# select the Camera file(s) -> dataProductCode == 'MSSCD'
item = db.dataframe[mask & (db.dataframe.dataProductCode == 'MSSCD')]
if item.empty:
    raise FileNotFoundError('No Camera file (MSSCD) found for the selected time range.')

# close the file of a previous run of this cell, if any (best effort)
if 'camera' in globals():
    try:
        camera.file_handler.close()
    except Exception as err:
        print(f'Could not close the previous Camera file: {err}')

# generate a virtual hdf5 to combine the datasets if there are multiple files selected
if len(item) > 1:
    vhdf5 = strawb.VirtualHDF5('MSSCD_event_view.hdf5', item.fullPath.to_list())
    file_name = vhdf5.file_name
else:
    # positional access: after masking, the index labels are those of db.dataframe
    file_name = item.fullPath.iloc[0]

camera = strawb.sensors.Camera(file=file_name)
camera.file_handler.file_attributes  # show the file attributes, time in seconds since epoch (1.1.1970) UTC

In [None]:
# select the Module file(s) -> dataProductCode == 'SMRD'
item = db.dataframe[mask & (db.dataframe.dataProductCode == 'SMRD')]
if item.empty:
    raise FileNotFoundError('No Module file (SMRD) found for the selected time range.')

# close the file of a previous run of this cell, if any (best effort)
if 'module' in globals():
    try:
        module.file_handler.close()
    except Exception as err:
        print(f'Could not close the previous Module file: {err}')

# generate a virtual hdf5 to combine the datasets if there are multiple files selected
if len(item) > 1:
    vhdf5 = strawb.VirtualHDF5('SMRD_event_view.hdf5', item.fullPath.to_list())
    file_name = vhdf5.file_name
else:
    # positional access: after masking, the index labels are those of db.dataframe
    file_name = item.fullPath.iloc[0]

module = strawb.sensors.Module(file=file_name)
module.file_handler.file_attributes  # show the file attributes, time in seconds since epoch (1.1.1970) UTC

# prepare colors

In [None]:
# prepare colors
### use a nice colormap, or use wavelength_to_rgb for conversion
# `plt.get_cmap` instead of `plt.cm.get_cmap`: matplotlib.cm.get_cmap is deprecated
# since Matplotlib 3.7 and removed in 3.9.
# NOTE(review): the meaning of the three positional arguments is defined by
# `add_colors` (strawb) — confirm before changing their order.
pmtspec.pmt_meta_data.add_colors(
    plt.get_cmap('viridis'),
    None,
    plt.get_cmap('gist_heat'),  # reddish color
)
pmtspec.pmt_meta_data.channel_meta_array

# Class that collects the data for the plot

In [None]:
class PlotData:
    """Collect the camera, PMTSpec and (optional) module data needed for the event-view plot.

    Everything is loaded in ``__init__`` and stored as attributes:

    - camera: ``timestamp`` (time of the selected frame), ``images_rgb`` (the selected,
      dark-subtracted frame), ``cam_time`` / ``cam_capture`` (see ``fill_where_exposed``)
    - PMT rates with masks for a window of one exposure time that starts 2 s after
      ``timestamp``: ``raw_time`` / ``raw_mask`` / ``raw_rate`` and, interpolated with
      ``interp_frequency = 5``, ``interp_time`` / ``interp_mask`` / ``interp_rate``
    - PMT rates interpolated with ``interp_frequency = 2`` (no time mask):
      ``interp_time_slow``, ``interp_rate_slow``, ``interp_active_ratio_slow``
    - ``power_devices_list``: power measurements of ``module`` restricted to the time
      range of ``interp_time_slow`` (empty if ``module`` is None)

    Note: ``pmtspec.trb_rates.interp_frequency`` is modified and left at 2.
    """

    def __init__(self, pmtspec, camera, index=27, module=None):
        """
        PARAMETER
        ---------
        pmtspec: strawb.sensors.PMTSpec
        camera: strawb.sensors.Camera
        index: int, optional
            index of the camera frame to show
        module: strawb.sensors.Module, optional
            if given, its power measurements are collected
        """
        self.pmtspec = pmtspec
        self.camera = camera
        self.index = index

        start_time, end_time = self._load_camera(camera, index)
        self._load_pmt_rates(pmtspec, start_time, end_time)
        self._load_power(module)

    def _load_camera(self, camera, index):
        """Load the selected camera frame; return the (start, end) time window for the PMT rates."""
        # the timestamp of the selected frame
        self.timestamp = camera.file_handler.time.asdatetime()[index]
        # NOTE(review): load_rgb(None, ...) appears to load all frames before one is selected
        self.images_rgb = camera.images.load_rgb(None, subtract_dark=True)[index]
        self.cam_time, self.cam_capture = self.fill_where_exposed(camera, index)

        # select the period of interest: one exposure time, starting 2 s after the frame
        # timestamp; adjust the offset for zoom-in/out
        exposure_time = np.timedelta64(int(camera.file_handler.exposure_time[index] * 1e3), "ms")
        start_time = self.timestamp + np.timedelta64(2, "s")
        end_time = start_time + exposure_time
        return start_time, end_time

    def _load_pmt_rates(self, pmtspec, start_time, end_time):
        """Load raw and interpolated PMT rates; the masks select [start_time, end_time]."""
        # raw data: time, mask for the selected time, rate
        print('raw rates')
        self.raw_time = pmtspec.trb_rates.time_middle
        self.raw_mask = (self.raw_time >= start_time) & (self.raw_time <= end_time)
        self.raw_rate = pmtspec.trb_rates.rate

        # interpolated data for the zoomed plot: time, mask for the selected time, rate
        print('interp frequency')
        pmtspec.trb_rates.interp_frequency = 5
        self.interp_time = pmtspec.trb_rates.interp_time
        self.interp_mask = (self.interp_time >= start_time) & (self.interp_time <= end_time)
        self.interp_rate = pmtspec.trb_rates.interp_rate

        # coarser interpolation for the timeline plots (no time mask)
        print('interp frequency slow')
        pmtspec.trb_rates.interp_frequency = 2
        self.interp_time_slow = pmtspec.trb_rates.interp_time
        self.interp_rate_slow = pmtspec.trb_rates.interp_rate
        self.interp_active_ratio_slow = pmtspec.trb_rates.interp_active_ratio

    def _load_power(self, module):
        """Collect the module power measurements within the time range of ``interp_time_slow``."""
        print('power measurements')
        self.power_devices_list = []
        if module is None:
            return

        # the time range is the same for all devices -> compute it once
        time_min = self.interp_time_slow.min()
        time_max = self.interp_time_slow.max(fill_value=0)  # fill_value=0 <-> np.nan -> 0
        for dev in module.power.all_devices_list:
            time_i = strawb.tools.asdatetime(dev.time)
            mask_i = time_min <= time_i
            mask_i &= time_max >= time_i
            dev_i = strawb.sensors.module.power.PowerDevice(name=dev.name,
                                                            time=time_i[mask_i],
                                                            current=dev.current[mask_i],
                                                            voltage=dev.voltage[mask_i])
            self.power_devices_list.append(dev_i)

    # Camera fill where exposed
    @staticmethod
    def fill_where_exposed(camera, index=None):
        """Create the arrays needed to shade the periods in which the camera exposed images.

        Usage example::

            plt.fill_between(strawb.tools.asdatetime(time), 0, 1,
                             where=capture>0,  # or >1 for the selected image
                             color='gray',
                             alpha=0.2,
                             label='exposure images',
                             transform=plt.gca().get_xaxis_transform())

        PARAMETER
        ---------
        camera: strawb.sensor.Camera
            reads ``camera.file_handler.time`` (frame start times) and
            ``camera.file_handler.exposure_time``
        index: int, optional
            to mark a selected frame (value 2 in ``capture``)
        RETURN
        ------
        time: ndarray, float
            time since epoch in seconds; 4 points per frame: [start, start, stop, stop]
        capture: ndarray, int
            0 outside an exposure, 1 during an exposure, 2 during the exposure of the
            selected frame
        """
        start_time = camera.file_handler.time[:]  # read the dataset once
        stop_time = start_time + camera.file_handler.exposure_time

        # points for [dis-, en-, en-, dis-]-able per frame
        time = np.array((start_time, start_time, stop_time, stop_time)).T
        capture = np.zeros_like(time, dtype=int)
        capture[:, 1:3] = 1

        if index is not None:
            capture[index, 1:3] = 2  # the selected frame

        return time.flatten(), capture.flatten()

# Get the index of the camera frame closest to the timestamp

In [None]:
# index of the camera frame that is closest in time to `timestamp`
frame_times = camera.file_handler.time.asdatetime()[:]
index = np.abs(frame_times - timestamp).argmin()

# The plot shows the indexes of all frames (top axis of the timeline).
# To select another frame, set it here and re-run this cell and the plot cells, e.g.:
# index = 35
print(f'Take picture with index: {index} at {frame_times[index]}')
plot_data = PlotData(
    pmtspec=pmtspec,
    camera=camera,
    index=index,
    module=module,
)

# Plot

In [None]:
with_timeline = True  # add the timeline rows (rates, active read ratio, power) below
linear_rate_scale = False  # True: linear y-axis for the zoomed PMT rates, False: log
plot_raw_points = False  # overlay the raw (not interpolated) rates as points

CM_PER_INCH = 2.54
figsize = np.array([14., 5.]) / CM_PER_INCH  # cm -> inch
nrows = 1
height_ratios = [2]

if with_timeline:
    figsize += np.array([0., 7.5]) / CM_PER_INCH  # cm -> inch
    nrows += 3
    height_ratios.extend([1, 1, 1])

# create the figure: image (left) and zoomed PMT rates (right) in the first row,
# the timeline plots span the full width of the rows below
fig = plt.figure(figsize=figsize)
gs = fig.add_gridspec(nrows, 2, width_ratios=[1, 2], height_ratios=height_ratios)

ax = [fig.add_subplot(gs[0, 0]), fig.add_subplot(gs[0, 1])]
if with_timeline:
    ax_row_1 = fig.add_subplot(gs[1, :])
    ax.append(ax_row_1)
    ax.append(fig.add_subplot(gs[2, :], sharex=ax_row_1))
    ax.append(fig.add_subplot(gs[3, :], sharex=ax_row_1))

### The image (already the selected frame)
ax[0].set_xlabel('Image of Event')
ax[0].xaxis.set_ticks([])
ax[0].yaxis.set_ticks([])
# NOTE(review): the division by 2**16 assumes 16-bit pixel values — confirm
ax[0].imshow(plot_data.images_rgb / 2 ** 16)

### The PMTs
# select the channels to plot by wavelength
# all: 0, 350, 400, 425, 450, 460, 470, 480, 492, 510, 525, 550
# `channel_mask` (not `mask`) so the file-selection mask of the notebook is not overwritten
wavelengths = pmtspec.pmt_meta_data.channel_meta_array['wavelength']
channel_mask = np.isin(wavelengths, [0, 400, 450, 480, 492, 510, 550])

for ch in pmtspec.pmt_meta_data.channel_meta_array[channel_mask]:
    ax[1].plot(
        plot_data.interp_time[plot_data.interp_mask],
        plot_data.interp_rate[ch["index"], plot_data.interp_mask],
        color=ch["color"],
        label=ch["label"],
    )

    if with_timeline:
        ax[2].plot(
            plot_data.interp_time_slow,
            plot_data.interp_rate_slow[ch["index"]],
            color=ch["color"],
        )
        ax[3].plot(
            plot_data.interp_time_slow,
            plot_data.interp_active_ratio_slow[ch["index"]],
            color=ch["color"],
        )

    if plot_raw_points:
        ### plot the RAW Data as points
        ax[1].plot(
            plot_data.raw_time[plot_data.raw_mask],
            plot_data.raw_rate[ch["index"], plot_data.raw_mask],
            color=ch["color"],
            marker="o", lw=0, alpha=0.01, ms=.5,
        )

if with_timeline:
    # Camera: all exposures (gray) and the selected frame (blue)
    cam_time = strawb.tools.asdatetime(plot_data.cam_time)
    ax[2].fill_between(cam_time, 0, 1,
                       where=plot_data.cam_capture > 0,
                       color='gray',
                       alpha=0.2,
                       label='exposure images',
                       transform=ax[2].get_xaxis_transform())
    ax[2].fill_between(cam_time, 0, 1,
                       where=plot_data.cam_capture > 1,
                       color='steelblue',
                       alpha=0.5,
                       label='selected image',
                       transform=ax[2].get_xaxis_transform())

    # second x-axis with the image index at the middle of each exposure;
    # cam_time holds 4 points per frame: [start, start, stop, stop]
    image_index_time = strawb.tools.asdatetime((plot_data.cam_time[::4] + plot_data.cam_time[2::4]) * .5)
    ax_x2 = ax[2].secondary_xaxis("top")
    xticks_pos = ax[2].convert_xunits(image_index_time)
    xticks_lab = np.arange(len(image_index_time))
    ax_x2.set_xticks(xticks_pos, minor=True)
    ax_x2.set_xticks(xticks_pos[::5], minor=False)
    ax_x2.set_xticklabels(xticks_lab[::5], minor=False)
    ax_x2.set_xlabel('Image Index')

    ax[2].set_ylabel('PMT Rates [Hz]')
    ax[2].legend(loc=1)
    ax[2].set_xlim((plot_data.interp_time_slow[0],
                    plot_data.interp_time_slow[-1]))
    ax[2].grid()
    ax[2].set_yscale('log')
    ax[2].set_ylim(bottom=1e2)

    ax[3].grid()
    ax[3].set_ylim(bottom=-0.01)
    ax[3].set_ylabel('Active read ratio')

    # power consumption of the module devices
    for dev in plot_data.power_devices_list:
        ax[4].plot(dev.time,
                   dev.watt,
                   label=module.power.dev_map(dev.name),
                   )
    ax[4].set_ylabel('Power consumption [W]')
    ax[4].legend(loc='upper right', ncol=2)
    ax[4].grid()

    ax[2].xaxis.set_major_formatter(
        mdates.ConciseDateFormatter(ax[2].xaxis.get_major_locator())
    )

ax[1].legend(loc='best', ncol=1, title='Color Channel')
ax[1].set_xlabel('Time [s]')
ax[1].set_ylabel('PMT Rates [Hz]')
zoom_time = plot_data.interp_time[plot_data.interp_mask]
if len(zoom_time) > 0:  # an empty window would raise an IndexError
    ax[1].set_xlim((zoom_time[0], zoom_time[-1]))
ax[1].xaxis.set_major_formatter(
    mdates.ConciseDateFormatter(ax[1].xaxis.get_major_locator())
)
ax[1].grid()

if linear_rate_scale:
    ax[1].set_ylim(0, plot_data.raw_rate.T[plot_data.raw_mask].max() * 1.05)
else:
    ax[1].set_yscale('log')
    ax[1].set_ylim(bottom=1e2)

In [None]:
import time

# Re-apply tight_layout a few times with a pause in between.
# NOTE(review): presumably this gives the interactive `notebook` backend time to draw
# the figure, so the text extents used by tight_layout are up to date — confirm.
n_layout_passes = 5
for _ in range(n_layout_passes):
    fig.tight_layout()
    time.sleep(1)

In [None]:
fig.tight_layout(pad=1.08)  # final layout pass; 1.08 is Matplotlib's default padding

In [None]:
# Save the figure as PDF and PNG; ':' is replaced since it is not allowed in file names on all systems.
output_dir = os.path.expanduser('~/Downloads')
os.makedirs(output_dir, exist_ok=True)  # the directory may not exist, e.g. on a server
file_stem = os.path.join(output_dir, f'pmt_event-{str(timestamp).replace(":", "-")}')
fig.savefig(f'{file_stem}.pdf', format='pdf')
fig.savefig(f'{file_stem}.png', format='png', facecolor='white')

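# Scratch: PMTSpec readout-timing checks
Exploratory cells comparing the file timestamps (`counts_time`) with the TRB rate timestamps (`rate_time`). They are not needed for the event view and are candidates for removal.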

In [None]:
# timestamps of the file samples (as float) and of the TRB rates
t_f = pmtspec.file_handler.counts_time[:].astype(float)
t = pmtspec.trb_rates.rate_time[:]

# time steps between consecutive samples
dt_f = np.diff(pmtspec.file_handler.counts_time[:])
dt = np.diff(pmtspec.trb_rates.rate_time)

In [None]:
# distribution of the time steps between the file samples
dt_bins = np.arange(0, 0.005, 0.00001)
plt.figure()
plt.hist(dt_f, bins=dt_bins)
plt.yscale('log')
plt.grid()

In [None]:
def f(x, y_0, m):
    """Straight line ``m * x + y_0``, used to fit the file time against the TRB rate time."""
    return m * x + y_0

# fit t_f(t); start values: offset = first file time, slope = 1
# (scipy.optimize is imported in the import cell at the top)
popt, pcov = scipy.optimize.curve_fit(f, xdata=t, ydata=t_f, p0=[t_f[0], 1.])

In [None]:
(popt[0], t_f[0], popt[0] - t_f[0])  # fitted offset, first file time, and their difference

In [None]:
# Fit the file time relative to its first sample: t_f - t_f[0] = m * t + y_0.
# Uses the straight line `f` defined above; scipy.optimize is imported at the top.
p0 = [0., 1.]
t_f_rel = t_f - t_f[0]

plt.figure()
plt.plot(t, f(t, *p0), '-', color='gray')  # start values
plt.plot(t, t_f_rel, 'o', alpha=.2, ms=2)
plt.grid()

popt, pcov = scipy.optimize.curve_fit(f, xdata=t, ydata=t_f_rel, p0=p0)
print(popt)
plt.plot(t, f(t, *popt), '-')

# residuals of the fit
plt.figure()
plt.plot(t, f(t, *popt) - t_f_rel, '-')

In [None]:
plt.figure()
plt.plot(t, t_f, linestyle='-')  # file time vs. TRB rate time

In [None]:
# residual of the relative fit (fit - data) together with the fitted offset
fit_residual = f(t, *popt) - (t_f - t_f[0])

plt.figure()
plt.plot(t, fit_residual, '-')
plt.axhline(popt[0], color='k')
plt.grid()
print(fit_residual.max())

In [None]:
t_f[0] + popt[0]  # fitted offset converted back to the absolute file time

In [None]:
# file time relative to its first sample vs. TRB rate time
# (the stray `import scipy.optimize` formerly at the end of this cell lives in the import cell)
plt.figure()
plt.plot(t, t_f - t_f[0], 'o', alpha=.2, ms=2)
plt.grid()

In [None]:
# readout frequency (1/dt) from the TRB rate time and from the file time;
# the two curves were both labelled 'Data' before and could not be told apart
freq_bins = np.linspace(0, 1e4, int(1e4))

plt.figure()
for label, dt_i in (('TRB rates (rate_time)', dt), ('file (counts_time)', dt_f)):
    hist, bin_edges = np.histogram(1. / dt_i, bins=freq_bins)
    # prepend 0 so the counts align with the bin edges (one more edge than counts)
    plt.plot(bin_edges, [0, *hist], label=label)
plt.legend()
plt.yscale('log')
plt.yscale('log')

In [None]:
# indices of the smallest TRB time step, and both time steps at the smallest file time step
# (the argwhere result was previously computed and discarded)
i_dt_min = np.argwhere(dt == dt.min()).ravel()
i = dt_f.argmin()
i_dt_min, i, dt_f[i], dt[i], dt.min()

In [None]:
plt.figure()
# TRB rate time step vs. file time step, markers only
plt.plot(dt, dt_f, marker='o', linestyle='None', alpha=.2, ms=2)
plt.grid()

In [None]:
plt.figure()
# readout frequency: TRB rate time vs. file time, markers only
plt.plot(1. / dt, 1. / dt_f, marker='o', linestyle='None', alpha=.2, ms=2)

In [None]:
# readout frequency (1/dt) of the file samples where the ch0 counts changed by the minimal step.
# `dc` is computed here because this cell runs before the cell that used to define it, and
# `mask_dc_min` is used so the file-selection `mask` of the notebook is not overwritten.
dc = np.diff(pmtspec.file_handler.counts_ch0)
mask_dc_min = dc == dc.min()
f_f = 1. / dt_f[mask_dc_min]
hist, bin_edges = np.histogram(f_f, np.linspace(0, f_f.max(), int(f_f.max() + 1)))

plt.figure()
plt.plot(bin_edges, [0, *hist])

In [None]:
dc = np.diff(pmtspec.file_handler.counts_ch0, n=1)  # change of the ch0 counts between consecutive samples

In [None]:
np.max(dc)  # largest ch0 count change

In [None]:
np.linspace(0, dc.max(), num=int(dc.max() + 1))  # one bin edge per count

In [None]:
# distribution of the ch0 count changes, one bin per count
dc_bins = np.linspace(0, dc.max(), int(dc.max() + 1))
hist, bin_edges = np.histogram(dc, bins=dc_bins)

plt.figure()
plt.plot(bin_edges, [0, *hist])
plt.yscale('log')

In [None]:
# fine-grained distribution of the TRB rate time steps
hist, bin_edges = np.histogram(dt, bins=10000)

plt.figure()
plt.plot(bin_edges, [0, *hist], label='Data')
plt.yscale('log')
plt.grid()