In [2]:
import os
import pandas as pd
import plotly.express as px
from lib.env import *
def run_new_state_machine_on_thresholded_predictions(predictions, min_puff_length=15, min_interpuff_length=50):
    """Debounce a thresholded (0/1) prediction stream into puff states and puff events.

    The machine has five states:

    * 0 - idle, no puff in progress.
    * 1 - inside a candidate puff that is still shorter than ``min_puff_length``.
    * 2 - inside a puff that has reached ``min_puff_length`` (a valid puff).
    * 3 - inside a gap that follows a candidate (not yet valid) puff.
    * 4 - inside a gap that follows a valid puff.

    A gap shorter than ``min_interpuff_length`` does not reset the puff-length
    counter, so positive samples on both sides of a short gap are counted
    toward the same puff. Once ``min_interpuff_length`` consecutive 0 samples
    have been seen the machine returns to idle; if it got there from state 4,
    the index of the sample that completed the gap is recorded as a puff.

    Args:
        predictions: iterable of samples that compare equal to 0 or 1
            (ints, floats or bools). Samples equal to neither value leave the
            machine unchanged.
        min_puff_length: number of positive samples required before a puff is
            considered valid. The default (15) matches the previously
            hard-coded ``> 14`` check.
        min_interpuff_length: number of consecutive 0 samples required to end
            a puff. The default (50) matches the previously hard-coded
            ``> 49`` check.

    Returns:
        ``(states, puff_locations)``. ``states`` has one entry per input
        sample; ``states[i]`` is the state after processing ``predictions[i]``,
        except that the final entry is always 0. ``puff_locations`` lists the
        sample indices at which a valid puff was closed. Empty input yields
        ``([], [])``.
    """
    IDLE, PUFF_PENDING, PUFF_VALID, GAP_AFTER_PENDING, GAP_AFTER_VALID = range(5)

    state = IDLE
    states = []
    puff_locations = []
    gap_length = 0
    puff_length = 0

    for i, sample in enumerate(predictions):
        # Record the state *before* this sample; the list is shifted by one
        # after the loop so that each entry reflects the state after its sample.
        states.append(state)

        is_puff = bool(sample == 1.0)
        is_rest = bool(sample == 0.0)
        if not (is_puff or is_rest):
            # Neither 0 nor 1: no transition is defined, keep everything as is.
            continue

        if state == IDLE:
            if is_puff:
                # start validating puff length
                state = PUFF_PENDING
                puff_length += 1
        elif state == PUFF_PENDING:
            if is_puff:
                puff_length += 1
                if puff_length >= min_puff_length:
                    state = PUFF_VALID
            else:
                # candidate puff interrupted; start timing the gap
                state = GAP_AFTER_PENDING
                gap_length += 1
        elif state == PUFF_VALID:
            if is_puff:
                puff_length += 1
            else:
                # valid puff may be ending; start timing the gap
                state = GAP_AFTER_VALID
                gap_length += 1
        elif state == GAP_AFTER_PENDING:
            if is_puff:
                # short gap bridged: keep accumulating the same candidate puff
                puff_length += 1
                gap_length = 0
                state = PUFF_VALID if puff_length >= min_puff_length else PUFF_PENDING
            else:
                gap_length += 1
                if gap_length >= min_interpuff_length:
                    # gap long enough: discard the candidate, nothing recorded
                    state = IDLE
                    puff_length = 0
                    gap_length = 0
        elif state == GAP_AFTER_VALID:
            if is_puff:
                # short gap bridged: back into the already valid puff
                gap_length = 0
                puff_length += 1
                state = PUFF_VALID
            else:
                gap_length += 1
                if gap_length >= min_interpuff_length:
                    # gap long enough: the valid puff is complete
                    state = IDLE
                    puff_length = 0
                    gap_length = 0
                    puff_locations.append(i)

    if not states:
        # Empty input: keep the output aligned with the input length
        # (the shift below would otherwise produce a spurious [0]).
        return [], puff_locations

    # Shift by one sample so states[i] is the state reached after sample i;
    # the last slot has no following sample and is reported as 0.
    states = states[1:] + [0]
    return states,puff_locations

In [6]:
# Load one recording of the 'riley' dataset into `df`.
DATASET_PATH = f'{DATA_PATH}/riley'
# os.listdir returns entries in arbitrary order, so RECORDING_IDS[0] could be a
# different recording on every run or machine. Sort the entries, and keep only
# directories (each recording is read from '<id>/raw/<id>.0.csv'), so that the
# selected recording is reproducible and stray files are never picked.
RECORDING_IDS = sorted(
    entry for entry in os.listdir(DATASET_PATH)
    if os.path.isdir(os.path.join(DATASET_PATH, entry))
)
RECORDING_PATH = f'{DATASET_PATH}/{RECORDING_IDS[0]}/raw/{RECORDING_IDS[0]}.0.csv'
# The first line of the file is skipped; the header is read from the second line.
df = pd.read_csv(RECORDING_PATH,skiprows=1)
# reset_index adds an 'index' column holding the row number.
df = df.reset_index()
fs = 20  # samples per second, as implied by the minutes conversion below - TODO confirm
# Make timestamps relative to the first sample and scale by 1e-9
# (presumably nanoseconds -> seconds; confirm against the recorder format).
df['timestamp'] = (df['timestamp'] - df['timestamp'].iloc[0])*1e-9
df['index'] = df['index']/(fs*60) # index in minutes

In [9]:
import torch
# Turn the three accelerometer columns into overlapping 100-sample windows.
# Row r of the final X is [acc_x[r:r+100], acc_y[r:r+100], acc_z[r:r+100]],
# i.e. 300 features per row and 99 fewer rows than df.
X = torch.from_numpy(df[['acc_x','acc_y','acc_z']].to_numpy())
x, y, z = (X[:, axis].unsqueeze(1) for axis in range(3))
# One slice per offset inside the window: [k : k-99] for k = 0..98, and [99:]
# for the last offset (k - 99 == 0 there, which `or None` turns into "to the end").
window_slices = [slice(k, (k - 99) or None) for k in range(100)]
xs = torch.cat([x[s] for s in window_slices],axis=1).float()
ys = torch.cat([y[s] for s in window_slices],axis=1).float()
zs = torch.cat([z[s] for s in window_slices],axis=1).float()
X = torch.cat([xs,ys,zs],axis=1)

In [11]:
# Instantiate the project's Casey1p1 model with its default constructor arguments.
# NOTE(review): no weights are loaded explicitly here - presumably the constructor
# handles that; confirm in lib.models.
from lib.models import Casey1p1
model = Casey1p1()

In [12]:
# Run the model on the windowed features and store its output next to the recording.
# NOTE(review): X has 99 fewer rows than df (100-sample windows), so this column
# assignment only lines up if the model returns one value per df row - confirm
# how Casey1p1 pads or aligns its output.
y_pred = model(X)
df['y_pred'] = y_pred

  0%|          | 0/60564 [00:00<?, ?it/s]

100%|██████████| 60564/60564 [00:04<00:00, 14313.52it/s]


In [13]:
# see the difference in activation
fig = px.line(df.iloc[::2],y=['rawlabel','y_pred'])
fig.show(renderer='browser')

In [16]:
# Binarise the recorded 'rawlabel' signal: 1 where it exceeds 0.85, else 0,
# and keep the result both as a list and as the 'label' column.
y_true = df['rawlabel'].to_list()
y_thresh_true = [int(value > .85) for value in y_true]
df['label'] = y_thresh_true

In [17]:
# Binarise the model output with the same 0.85 cut-off used for the labels.
y_thresh = []
for activation in y_pred:
    y_thresh.append(1 if activation > .85 else 0)
df['y_thresh'] = y_thresh

In [18]:
# Compare the thresholded ground truth ('label') with the thresholded model
# output ('y_thresh'); every second row is plotted and shown in the browser.
fig = px.line(df.iloc[::2],y=['label','y_thresh'])
fig.show(renderer='browser')

In [19]:
# Feed the thresholded predictions through the puff state machine.
# y_state holds one state per sample; puff_locations holds the sample indices
# at which a valid puff was closed.
y_state,puff_locations = run_new_state_machine_on_thresholded_predictions(y_thresh)

In [20]:
# Store the computed states and plot them against the 'state' column.
# NOTE(review): 'state' is not created in the visible cells, so it presumably
# comes from the recording CSV - confirm.
df['y_state'] = y_state
fig = px.line(df.iloc[::2],y=['state','y_state'])
fig.show(renderer='browser')

In [21]:
# Same state comparison as before, with a vertical line at every detected puff.
# puff_locations are sample indices produced by enumerate() in the state machine;
# no x column is passed to px.line, so the lines are placed by sample index.
fig = px.line(df.iloc[::2],y=['state','y_state'])
for puff_loc in puff_locations:
    fig.add_vline(x=puff_loc)
fig.show(renderer='browser')

In [29]:
# Show the x/y accelerometer channels with a vertical line at every detected
# puff, to inspect the raw motion around each detection.
fig = px.line(df.iloc[::2],y=['acc_x','acc_y'])
for puff_loc in puff_locations:
    fig.add_vline(x=puff_loc)
fig.show(renderer='browser')