In [37]:
from lib.utils import *
from lib.models import *
from lib.ekyn import *
from lib.env import *
from lib.datasets import *

In [38]:
def fix_gaps(df, time_col=None, fill_value='X', epoch_seconds=10):
    """Insert placeholder rows wherever consecutive epochs are too far apart.

    Rows are kept in their original order. Whenever the time between two
    consecutive rows exceeds one epoch, ``gap // epoch - 1`` filler rows are
    inserted between them, spaced one epoch apart starting from the earlier
    row. A gap shorter than two epochs therefore adds no rows.

    Parameters
    ----------
    df : pandas.DataFrame
        Scored epochs, one row per epoch, with a datetime column.
    time_col : label, optional
        Column holding the epoch timestamps. Defaults to the first column of
        ``df``, so exports that name it differently (e.g. 'Time Stamp' or
        'Start Time') work without changes.
    fill_value : object, default 'X'
        Value written into every non-time column of an inserted row.
    epoch_seconds : int or float, default 10
        Expected spacing between consecutive rows, in seconds.

    Returns
    -------
    pandas.DataFrame
        Copy of ``df`` with a fresh 0..n-1 index and the filler rows added.
    """
    df = df.reset_index(drop=True)
    if df.empty:
        return df
    if time_col is None:
        time_col = df.columns[0]

    epoch = pd.Timedelta(seconds=epoch_seconds)
    stamps = pd.to_datetime(df[time_col])
    steps = stamps.diff()
    # After reset_index the index labels are also the row positions.
    gap_rows = steps[steps > epoch].index.tolist()
    if not gap_rows:
        return df

    other_cols = [col for col in df.columns if col != time_col]
    pieces = []
    previous = 0
    for row in gap_rows:
        pieces.append(df.iloc[previous:row])
        # Timedelta floor division is exact, unlike differences of float
        # epoch timestamps, so fractional-second stamps cannot lose an epoch.
        missing = int(steps.iloc[row] // epoch) - 1
        if missing > 0:
            last_seen = stamps.iloc[row - 1]
            filler = pd.DataFrame(
                {time_col: [last_seen + (k + 1) * epoch for k in range(missing)]}
            )
            for col in other_cols:
                filler[col] = fill_value
            pieces.append(filler[list(df.columns)])
        previous = row
    pieces.append(df.iloc[previous:])

    # One concat for the whole frame instead of one per inserted row.
    return pd.concat(pieces).reset_index(drop=True)

In [43]:
def load_feature_label_pair_snezana(filename, fs=500, epoch_seconds=10, label_categories=None):
    """Load one recording as a pair of (EEG epochs, one-hot labels).

    Reads ``../data/snezana_mice/{filename}.edf`` (channel 'EEG 1') and the
    matching ``.xlsx`` scoring sheet, aligns the EEG to the first scored epoch
    and cuts it into fixed-length epochs.

    Parameters
    ----------
    filename : str
        Recording id, i.e. the file stem shared by the .edf and .xlsx files.
    fs : int, default 500
        Sampling rate used to convert seconds to samples.
        NOTE(review): this is not checked against the EDF header - confirm.
    epoch_seconds : int, default 10
        Length of one scored epoch in seconds.
    label_categories : sequence, optional
        Fixed, ordered list of label values. When given, class indices and the
        one-hot width are the same for every recording. When omitted, the
        categories are inferred from the labels present in this file, so the
        index of a stage can shift if a stage is absent.

    Returns
    -------
    (torch.Tensor, torch.Tensor)
        ``eeg`` of shape (n_epochs, fs * epoch_seconds) and the float one-hot
        labels, both truncated to the same number of epochs.

    Raises
    ------
    ValueError
        If the EDF has no measurement date, or scoring starts before the
        recording does.
    """
    samples_per_epoch = fs * epoch_seconds

    raw = read_raw_edf((f'../data/snezana_mice/{filename}.edf'),verbose=False)
    measurement_date = raw.info["meas_date"]
    if measurement_date is None:
        raise ValueError(f'{filename}.edf has no measurement date; cannot align labels to the EEG')
    # The scoring sheet timestamps are naive, so compare naive to naive.
    measurement_date = measurement_date.replace(tzinfo=None)
    eeg = raw.get_data(picks='EEG 1')[0]

    df = pd.read_excel(f'../data/snezana_mice/{filename}.xlsx')
    # NOTE(review): the first sheet row is discarded - presumably not a scored epoch; confirm.
    df = df.drop(0).reset_index(drop=True)
    df = fix_gaps(df)
    # Unscored ('X') epochs inherit the previous scored label.
    df.loc[df['Rodent Sleep'] == 'X','Rodent Sleep'] = NA
    df = df.ffill()

    start_time = df['Time Stamp'].iloc[0]
    end_time = df['Time Stamp'].iloc[-1]
    n_label_epochs = int((end_time - start_time).total_seconds() / epoch_seconds) + 1

    # total_seconds() keeps the days component that `.seconds` silently drops.
    offset_seconds = int((start_time - measurement_date).total_seconds())
    if offset_seconds < 0:
        raise ValueError(
            f'{filename}: scoring starts {-offset_seconds}s before the recording; cannot align'
        )
    eeg = eeg[offset_seconds * fs:]

    # Keep whole epochs only so the reshape cannot fail on a partial last epoch.
    n_epochs = min(n_label_epochs, len(eeg) // samples_per_epoch)
    eeg = eeg[:n_epochs * samples_per_epoch]
    eeg = from_numpy(eeg.reshape(n_epochs, samples_per_epoch)).float()

    codes = Categorical(df['Rodent Sleep'], categories=label_categories).codes.copy()
    num_classes = -1 if label_categories is None else len(label_categories)
    y = one_hot(from_numpy(codes).long(), num_classes=num_classes).float()

    # Return the same number of epochs on both sides.
    n_common = min(len(eeg), len(y))
    return eeg[:n_common], y[:n_common]

In [44]:
# Load the EEG epochs (X) and one-hot labels (y) for recording 21-WT-1.
X, y = load_feature_label_pair_snezana(filename='21-WT-1')

  warn("Workbook contains no default style, apply openpyxl's default")


In [45]:
# Restore the saved weights on CPU first, then move the model to DEVICE.
checkpoint_path = '../models/spindle_gandalfs/gandalf_spindle_fold_00/best_model.pt'
state_dict = torch.load(f=checkpoint_path, map_location='cpu')

model = Gandalf()
model.load_state_dict(state_dict)
model.to(DEVICE)

# Evaluate on the loaded recording in its original order (no shuffling).
criterion = torch.nn.CrossEntropyLoss()
dataloader = DataLoader(Windowset(X, y), batch_size=32, shuffle=False)
loss, report, y_true, y_pred, logits = evaluate(dataloader, model, criterion, DEVICE)

100%|██████████| 540/540 [00:20<00:00, 26.15it/s]
  _warn_prf(average, modifier, msg_start, len(result))


In [47]:
# Show the precision / recall / f1 summary returned by evaluate().
# NOTE(review): a recall of exactly 1/3 would be consistent with an average over
# three classes where only one class is ever predicted - confirm that the label
# codes produced by the loader match the class order the model was trained with.
report

{'precision': 0.1681520061728395,
 'recall': 0.3333333333333333,
 'f1': 0.2235386134297547}

In [None]:
# Pass the y_true / y_pred returned by evaluate() to cm_grid.
# NOTE(review): presumably draws confusion matrices - confirm against cm_grid's definition.
cm_grid(y_true,y_pred)

In [None]:

# all_metrics = pd.concat([all_metrics,pd.Series(metrics,name=f'{filename}')],axis=1)

In [None]:
# Inspect the dtype of y_true (it is passed to get_bout_statistics_for_predictions in a later cell).
y_true.dtype

In [None]:
def get_bout_statistics_for_predictions(pred, epoch_seconds=10):
    """Summarise sleep-stage bouts in a sequence of per-epoch stages.

    A bout is a maximal run of consecutive epochs that share the same stage.

    Parameters
    ----------
    pred : 1-D sequence
        Either numeric class codes (0 -> 'P', 1 -> 'S', 2 -> 'W') or stage
        letters drawn from 'P', 'S', 'W', 'X', 'A'. Tensors (anything with
        ``detach``), NumPy arrays, pandas Series and lists are accepted, and
        numeric codes may have any integer or float dtype. An (n, 1) input is
        treated as its single column.
    epoch_seconds : int or float, default 10
        Duration of one epoch in seconds.

    Returns
    -------
    pandas.DataFrame
        Rows 'total' (minutes per stage), 'average' (mean bout length in
        seconds, NaN when the stage never occurs) and 'counts' (number of
        bouts); columns 'P', 'S', 'W'. 'X' and 'A' epochs split bouts but are
        not reported.

    Raises
    ------
    ValueError
        If ``pred`` is not one-dimensional.
    KeyError
        If ``pred`` contains a value that is not a known code or stage.
    """
    code_to_stage = {0: 'P', 1: 'S', 2: 'W'}
    reported_stages = ('P', 'S', 'W')
    ignored_stages = ('X', 'A')

    # Duck-typed so tensors are handled without depending on their dtype.
    if hasattr(pred, 'detach'):
        pred = pred.detach().cpu().numpy()
    values = np.asarray(pred)
    if values.ndim == 2 and values.shape[1] == 1:
        values = values[:, 0]
    if values.ndim != 1:
        raise ValueError(f'pred must be one-dimensional, got shape {values.shape}')

    stages = values.tolist()
    if values.dtype.kind in 'iuf':
        # 2 and 2.0 hash alike, so integer and float codes share one lookup.
        stages = [code_to_stage.get(code, code) for code in stages]

    known = set(reported_stages) | set(ignored_stages)
    unknown = sorted({repr(stage) for stage in stages if stage not in known})
    if unknown:
        raise KeyError(f"unrecognised stage value(s): {', '.join(unknown)}")

    # Run-length encode the sequence, keeping only the reported stages.
    bout_lengths = {stage: [] for stage in reported_stages}
    current_stage, current_length = None, 0
    for stage in stages:
        if stage != current_stage:
            if current_stage in bout_lengths:
                bout_lengths[current_stage].append(current_length)
            current_stage, current_length = stage, 0
        current_length += 1
    if current_stage in bout_lengths:
        bout_lengths[current_stage].append(current_length)

    total = {key: sum(bout_lengths[key]) * epoch_seconds / 60 for key in bout_lengths}
    # Avoid np.mean([]) so an absent stage yields NaN without a RuntimeWarning.
    average = {
        key: (np.mean(bout_lengths[key]) * epoch_seconds if bout_lengths[key] else np.nan)
        for key in bout_lengths
    }
    counts = {key: len(bout_lengths[key]) for key in bout_lengths}

    return pd.DataFrame([pd.Series(total,name='total'),pd.Series(average,name='average'),pd.Series(counts,name='counts')])

In [None]:
# Bout statistics (rows: total, average, counts) for y_true as returned by evaluate().
# NOTE: the following cell reuses the name `stats` for y_pred.
stats = get_bout_statistics_for_predictions(y_true)
stats

In [None]:
# Bout statistics (rows: total, average, counts) for y_pred as returned by evaluate().
# NOTE: this overwrites the `stats` computed for y_true in the previous cell.
stats = get_bout_statistics_for_predictions(y_pred)
stats

In [None]:
import plotly.express as px

# One line per column over the epoch index: column 0 is y_pred, column 1 is y_true.
pred_vs_true = pd.DataFrame([y_pred.numpy(), y_true.numpy()]).T
fig = px.line(data_frame=pred_vs_true)
# Opens the interactive figure in a browser tab instead of rendering it inline.
fig.show(renderer='browser')

In [None]:
# Evaluate `model` on every recording in `ids` and append each recording's metrics
# as a new column of `all_metrics`.
# NOTE(review): `ids` and `all_metrics` are not created in this cell; they must
# already exist when it runs.
for filename in ids:
    print(filename)
    if filename == '22-Oct-A':
        continue
    fs = 500
    samples_per_epoch = fs * 10  # 10-second epochs
    raw = read_raw_edf((f'../data/courtney_aug_oct_2022_baseline_recordings/1_raw_edf/{filename}.edf'),verbose=False)
    # The label sheet timestamps are naive, so compare naive to naive.
    measurement_date = raw.info["meas_date"].replace(tzinfo=None)
    # Read the channel once; the original fetched it a second time further down.
    eeg = raw.get_data(picks='EEG 1')[0]
    df = pd.read_excel(f'../data/courtney_aug_oct_2022_baseline_recordings/2_labels/CW {filename} Baseline.xls')
    df = df.drop(0).reset_index(drop=True)
    # fix_gaps may look the time column up by the name 'Time Stamp', while these
    # sheets call it 'Start Time'; rename around the call so it works either way.
    df = fix_gaps(df.rename(columns={'Start Time': 'Time Stamp'}))
    df = df.rename(columns={'Time Stamp': 'Start Time'})
    # Unscored ('X') epochs inherit the previous scored label.
    df.loc[df['Label'] == 'X','Label'] = NA
    df = df.ffill()
    print(df['Start Time'].diff().value_counts())
    print(df['Label'].value_counts())
    start_time = df['Start Time'].iloc[0]
    end_time = df['Start Time'].iloc[-1]
    length = (end_time - start_time)
    times = [start_time + datetime.timedelta(seconds=10*i) for i in range(int(length.total_seconds()/10)+1)]
    print(len(times))
    print(times[0])
    print(times[-1])
    print(eeg.shape[0]/fs)
    # total_seconds() keeps the days component that `.seconds` silently drops.
    offset_seconds = int((start_time - measurement_date).total_seconds())
    if offset_seconds < 0:
        raise ValueError(f'{filename}: labels start {-offset_seconds}s before the recording; cannot align')
    eeg = eeg[offset_seconds*fs:]
    print(eeg.shape[0]/fs)
    # Keep whole epochs only so the reshape cannot fail on a partial last epoch.
    n_epochs = min(len(times), eeg.shape[0] // samples_per_epoch)
    eeg = eeg[:n_epochs*samples_per_epoch]
    print(eeg.shape[0]/fs)
    eeg = from_numpy(eeg.reshape(n_epochs, samples_per_epoch)).float()
    print(eeg.shape)
    y = one_hot(from_numpy(Categorical(df['Label']).codes.copy()).long()).float()
    # Give the dataset the same number of epochs on both sides.
    n_common = min(len(eeg), len(y))
    eeg, y = eeg[:n_common], y[:n_common]
    dataloader = DataLoader(Windowset(eeg,y),batch_size=32,shuffle=False)
    loss,metrics,y_true,y_pred,logits = evaluate(dataloader,model,criterion,DEVICE)
    all_metrics = pd.concat([all_metrics,pd.Series(metrics,name=f'{filename}')],axis=1)