In [None]:
# Re-import edited modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import os

# Directory holding the <id>.edf / <id>.zdb pairs to score (the .csv predictions are written here too).
# Override with the SLEEP_DATA_TO_SCORE environment variable; defaults to the original location.
PATH_TO_DIRECTORY_OF_DATA_TO_SCORE = os.environ.get('SLEEP_DATA_TO_SCORE', '/home/andrew/sleep/data_to_score')
# Recording id -> name of the EEG channel to read from that recording's EDF file.
MAPPING_OF_FILE_IDS_TO_CHANNEL_NAME = {
    "SPARC-10-M": "EEG 1",
    "SPARC-11-M": "EEG 1",
    "SPARC-12-M": "EEG 1",
    "SPARC-13-M": "EEG 1",
    "SPARC-14-M": "EEG 1",
    "SPARC-15-M": "EEG 1",
    "SPARC-17-M": "EEG 2",
    "SPARC-18-M": "EEG 1",
    "SPARC-2-F": "EEG 1",
    "SPARC-4-F": "EEG 1",
    "SPARC-5-F": "Channel-2",
    "SPARC-8-F": "Channel-10",
    "SPARC-9-F": "Channel-2",
    "SPARC-10-F": "Channel-50",
    "SPARC-11-F": "Channel-18",
    "SPARC-12-F": "Channel-42",
    "SPARC-13-F": "EEG 1",
    "SPARC-14-F": "EEG 1",
    "SPARC-15-F": "EEG 1",
    "SPARC-16-F": "EEG 2"
}
# Recording id = file name up to its first '.'. Entries whose id would be empty (hidden files such as
# .DS_Store or .ipynb_checkpoints) are skipped: an empty id sorts first and would shift every IDS[i].
IDS = sorted({
    stem
    for stem in (name.split('.')[0] for name in os.listdir(PATH_TO_DIRECTORY_OF_DATA_TO_SCORE))
    if stem
})
print(IDS)

In [None]:
from mne.io import read_raw_edf
import torch
# Which recording in IDS to score (index into the sorted id list printed above).
RECORDING_INDEX = 1
# Samples per scoring epoch; must match the model's input width (Frodo n_features=5000 below).
SAMPLES_PER_EPOCH = 5000

# `id` shadows the builtin, but later cells read it, so the name is kept.
id = IDS[RECORDING_INDEX]
print(id)
zdb_path = f'{PATH_TO_DIRECTORY_OF_DATA_TO_SCORE}/{id}.zdb'
edf_path = f'{PATH_TO_DIRECTORY_OF_DATA_TO_SCORE}/{id}.edf'
# Fail fast if either file of the pair is missing (zdb checked first, as before).
for required_path in (zdb_path, edf_path):
    if not os.path.exists(required_path):
        raise FileNotFoundError(required_path)
eeg_channel = MAPPING_OF_FILE_IDS_TO_CHANNEL_NAME[id]
signal = read_raw_edf(edf_path).get_data(picks=[eeg_channel])[0]
# reshape(-1, SAMPLES_PER_EPOCH) raises unless the length is an exact multiple,
# so drop a trailing partial epoch instead of crashing.
n_epochs = len(signal) // SAMPLES_PER_EPOCH
n_dropped = len(signal) - n_epochs * SAMPLES_PER_EPOCH
if n_dropped:
    print(f'dropping {n_dropped} trailing samples (incomplete final epoch)')
# Shape: (n_epochs, SAMPLES_PER_EPOCH), float32.
eeg = torch.from_numpy(signal[:n_epochs * SAMPLES_PER_EPOCH]).float().reshape(-1, SAMPLES_PER_EPOCH)

In [None]:
# (n_epochs, 5000): one row of samples per scoring epoch.
eeg.shape

In [None]:
# Trained Gandalf weights (a state_dict, loaded further down in this cell).
PATH_TO_MODEL = '/home/andrew/sleep/models/gandalfs/gandalf_0/best_model.pt'
from torch import nn
from torch.nn.functional import relu
# Prefer the GPU, but fall back to CPU when CUDA is unavailable.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
class ResidualBlock(nn.Module):
    """Three conv -> LayerNorm -> ReLU stages plus a 1x1-conv projection shortcut.

    Main-path kernel sizes are 8, 5 and 3. Every conv uses 'same' padding and no bias,
    and every LayerNorm normalises over (out_feature_maps, n_features) with no learned
    affine, so input and output share the time length ``n_features``.
    """
    def __init__(self,in_feature_maps,out_feature_maps,n_features) -> None:
        super().__init__()
        norm_shape = (out_feature_maps, n_features)

        def make_conv(channels_in, kernel):
            return nn.Conv1d(channels_in, out_feature_maps, kernel_size=kernel, padding='same', bias=False)

        def make_norm():
            return nn.LayerNorm(norm_shape, elementwise_affine=False)

        # Main path (creation/registration order kept so state_dict keys and seeded init match).
        self.c1, self.bn1 = make_conv(in_feature_maps, 8), make_norm()
        self.c2, self.bn2 = make_conv(out_feature_maps, 5), make_norm()
        self.c3, self.bn3 = make_conv(out_feature_maps, 3), make_norm()
        # Shortcut: 1x1 conv projecting the input to out_feature_maps channels.
        self.c4, self.bn4 = make_conv(in_feature_maps, 1), make_norm()

    def forward(self,x):
        shortcut = self.bn4(self.c4(x))
        hidden = x
        for conv, norm in ((self.c1, self.bn1), (self.c2, self.bn2), (self.c3, self.bn3)):
            hidden = relu(norm(conv(hidden)))
        return relu(hidden + shortcut)

class Frodo(nn.Module):
    """
    the little wanderer

    Per-epoch encoder: three ResidualBlocks (1 -> 8 -> 16 -> 16 channels, time length
    ``n_features`` preserved), global average pooling over time, and an optional
    3-way linear head.

    forward(x, classification=True):
        x: anything viewable as (batch, 1, n_features).
        returns (batch, 3) logits when ``classification`` is true, otherwise the
        (batch, 16) pooled embedding. The batch axis is kept even when batch == 1.
    """
    def __init__(self,n_features,device='cuda') -> None:
        # `device` is accepted for backward compatibility but not used; callers move the module with .to().
        super().__init__()
        self.n_features = n_features
        self.block1 = ResidualBlock(1,8,n_features)
        self.block2 = ResidualBlock(8,16,n_features)
        self.block3 = ResidualBlock(16,16,n_features)

        self.gap = nn.AvgPool1d(kernel_size=n_features)
        self.fc1 = nn.Linear(in_features=16,out_features=3)
    def forward(self,x,classification=True):
        x = x.view(-1,1,self.n_features)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        # (batch, 16, 1) -> (batch, 16). Squeeze only the pooled axis: a bare squeeze() also
        # removed the batch axis when batch == 1, breaking callers that softmax over dim=1.
        x = self.gap(x).squeeze(-1)
        if classification:
            return self.fc1(x)
        return x
        
class Gandalf(nn.Module):
    """Classifies the centre of a window of consecutive epochs.

    Each of the ``n_context`` epochs in a window is encoded independently by a shared
    Frodo encoder into a 16-d vector. The sequence goes through a bidirectional LSTM
    (hidden 32 per direction), and the LSTM output at the last time step feeds a 3-way
    linear head.

    forward(x_2d, classification=True):
        x_2d: anything viewable as (batch, n_context, 1, n_features), e.g. a flattened window.
        returns (batch, 3) logits, or the (batch, 64) LSTM output if ``classification`` is false.
    """
    def __init__(self, n_context=9, n_features=5000) -> None:
        super().__init__()
        # Plain ints, not parameters: the state_dict layout is unchanged.
        self.n_context = n_context
        self.n_features = n_features
        self.encoder = Frodo(n_features=n_features,device=DEVICE).to(DEVICE)
        self.lstm = nn.LSTM(16,32,bidirectional=True)
        self.fc1 = nn.Linear(64,3)
    def forward(self,x_2d,classification=True):
        x_2d = x_2d.view(-1, self.n_context, 1, self.n_features)
        # Encode every window position and stack along a new leading time axis, giving the
        # (seq, batch, feature) layout nn.LSTM uses by default. torch.stack allocates once and
        # follows the input's device, unlike growing a tensor created on the global DEVICE.
        x = torch.stack(
            [self.encoder(x_2d[:, t], classification=False) for t in range(x_2d.size(1))],
            dim=0,
        )
        out,_ = self.lstm(x)
        # NOTE(review): out[-1] is the last time step; for the backward direction that step has
        # only seen the last epoch. Kept as-is because the trained weights depend on it.
        if classification:
            return self.fc1(out[-1])
        return out[-1]
model = Gandalf()
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code. Only load trusted
# checkpoints; if the installed torch supports it, weights_only=True restricts loading to tensors.
model.load_state_dict(torch.load(f=PATH_TO_MODEL,map_location='cpu'))
# DEVICE comes from the config lines at the top of this cell. Re-assigning it here used to
# silently override that choice.
model.to(DEVICE);

In [None]:
from torch.utils.data import Dataset,DataLoader
class FeatureSet(Dataset):
    """Sliding-window dataset over the rows (epochs) of ``X``.

    Item ``idx`` holds rows ``idx - half_window .. idx + half_window`` of ``X``, flattened to 1-D.
    Rows outside ``X`` are zeros, so epoch ``idx`` is always the centre of its window.
    There are ``len(X)`` items.

    Args:
        X: tensor of shape (n_epochs, ...); the padding matches its trailing shape, dtype and device.
        half_window: context epochs on each side (default 4 -> 9-epoch windows).
    """
    def __init__(self, X, half_window=4):
        self.len = len(X)
        self.window = 2 * half_window + 1
        pad = torch.zeros((half_window, *X.shape[1:]), dtype=X.dtype, device=X.device)
        self.X = torch.cat([pad, X, pad])

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        # Padded rows idx .. idx + window - 1 correspond to original rows idx - half_window .. idx + half_window.
        return self.X[idx:idx + self.window].flatten()
# Flattened 9-epoch windows in recording order (shuffle=False), so predictions line up with epochs.
dataloader = DataLoader(dataset=FeatureSet(eeg),batch_size=32,shuffle=False)

In [None]:
from tqdm import tqdm

In [None]:
# Score every epoch. Despite its name, y_logits holds softmax class probabilities (n_epochs, 3);
# y_pred holds the argmax class index per epoch. Both end up on the CPU.
# The layers used (Conv1d, LayerNorm, LSTM without dropout, Linear) behave the same in train and
# eval mode, but eval() makes inference mode explicit.
model.eval()
with torch.no_grad():
    prob_batches = []
    for Xi in tqdm(dataloader):
        logits = model(Xi.to(DEVICE))
        prob_batches.append(torch.softmax(logits, dim=1).cpu())
# Concatenate once instead of re-allocating the running result on every batch.
if prob_batches:
    y_logits = torch.cat(prob_batches)
    y_pred = y_logits.argmax(dim=1)
else:
    y_logits = torch.Tensor()
    y_pred = torch.Tensor()

In [None]:
import pandas as pd
# Class index -> stage letter (the order used for the CSV consumed by the .zdb writer).
STAGE_LETTERS = {0: 'P', 1: 'S', 2: 'W'}
# Build a new frame rather than overwriting the y_pred tensor, so this cell can be re-run.
# Map all codes in one step: assigning strings into an int column label-by-label relies on
# pandas up-casting the column in place, which recent pandas versions warn about.
y_pred_labels = pd.DataFrame({0: pd.Series(y_pred.numpy()).astype(int).map(STAGE_LETTERS)})
csv_path = f'{PATH_TO_DIRECTORY_OF_DATA_TO_SCORE}/{id}.csv'

y_pred_labels.to_csv(csv_path,index=False) # write to file because inference takes time and you can resume here

In [None]:
import sqlite3
from sqlite3 import Error
# Stage letter (as written to the CSV) -> scoring label stored in the .zdb.
rename_dict = {'W':'Sleep-Wake', 'S':'Sleep-SWS', 'P':'Sleep-Paradoxical', 'X':''}
# Epoch time period in .zdb time units. RecordingStart/Stop are converted to seconds with a
# factor of 1e-7 below, i.e. one unit is 100 ns, so 10**8 units = 10 s. This is kept as an exact
# int: the stored timestamps exceed 2**53, so float arithmetic (the former `10e7`) is not exact
# in general, and such floats are formatted in scientific notation.
offset = 10**8

y_pred = pd.read_csv(csv_path)

# Created only when missing from the file (DDL unchanged).
_SCORING_TABLE_DDL = {
    'scoring_revision': "CREATE TABLE scoring_revision (id INTEGER PRIMARY KEY, name TEXT, is_deleted INTEGER(1), tags TEXT, version INTEGER(8), owner TEXT, date_created INTEGER(8));",
    'scoring_marker': "CREATE TABLE scoring_marker (id INTEGER PRIMARY KEY, starts_at INTEGER(8), ends_at INTEGER(8), notes TEXT, type TEXT, location TEXT, is_deleted INTEGER(1), key_id INTEGER);",
    'scoring_comment': "CREATE TABLE scoring_comment (id INTEGER PRIMARY KEY, category TEXT, key TEXT, value TEXT);",
    'scoring_key': "CREATE TABLE scoring_key (id INTEGER PRIMARY KEY, date_created INTEGER(8), name TEXT, owner TEXT, type TEXT);",
    'scoring_revision_to_comment': "CREATE TABLE scoring_revision_to_comment (revision_id INTEGER(8), comment_id INTEGER(8));",
    'scoring_revision_to_key': "CREATE TABLE scoring_revision_to_key (revision_id INTEGER(8), key_id INTEGER(8));",
}


def _read_int_property(cur, key):
    """Return internal_property[key] as an int; raise KeyError if the row is missing."""
    cur.execute("SELECT value FROM internal_property WHERE key=?", (key,))
    row = cur.fetchone()
    if row is None:
        raise KeyError(f"internal_property has no {key!r} row")
    return int(row[0])


def _ensure_scoring_tables(cur):
    """Drop temporary_scoring_marker, then create any missing scoring_* tables."""
    # drop this table - creates issues
    cur.execute("DROP TABLE IF EXISTS temporary_scoring_marker;")
    cur.execute("SELECT name FROM sqlite_master WHERE type='table';")
    existing = {row[0] for row in cur.fetchall()}
    for table, ddl in _SCORING_TABLE_DDL.items():
        if table not in existing:
            cur.execute(ddl)


def _max_id(cur, query, params=()):
    """Run a `SELECT MAX(id) ...` query and return the single value."""
    cur.execute(query, params)
    return cur.fetchone()[0]


def _replace_scoring(cur, stages, recording_start, epoch_len, labels):
    """Delete all existing scoring and insert one 'LSTM' revision with one marker per stage.

    Markers are contiguous, ``epoch_len`` long, starting at ``recording_start`` floored to a
    multiple of ``epoch_len``. Returns the id of the new scoring_revision row.
    """
    # delete existing score before adding machine data
    for table in ('scoring_marker', 'scoring_revision', 'scoring_key', 'scoring_revision_to_key'):
        cur.execute(f"DELETE FROM {table};")

    cur.execute(
        "INSERT INTO scoring_revision (name, is_deleted, version, date_created) VALUES (?, ?, ?, ?);",
        ('LSTM', 0, 0, recording_start),
    )
    revision_id = _max_id(cur, "SELECT MAX(id) FROM scoring_revision WHERE name=?", ('LSTM',))
    cur.execute(
        "INSERT INTO scoring_key (date_created, type) VALUES (?, ?);",
        (recording_start, 'Automatic'),
    )
    key_row_id = _max_id(cur, "SELECT MAX(id) FROM scoring_key")
    # Link the rows just inserted instead of assuming both received id 1.
    cur.execute(
        "INSERT INTO scoring_revision_to_key (revision_id, key_id) VALUES (?, ?);",
        (revision_id, key_row_id),
    )

    first_start = recording_start - (recording_start % epoch_len)
    # NOTE(review): as before, markers' key_id is the scoring_revision id; confirm whether the
    # .zdb format expects scoring_key.id here (the two coincide after the DELETEs above in a
    # non-AUTOINCREMENT table).
    markers = [
        (first_start + i * epoch_len, first_start + (i + 1) * epoch_len, '', labels[stage], '', 0, revision_id)
        for i, stage in enumerate(stages)
    ]
    cur.executemany(
        "INSERT INTO scoring_marker (starts_at, ends_at, notes, type, location, is_deleted, key_id) "
        "VALUES (?, ?, ?, ?, ?, ?, ?);",
        markers,
    )
    return revision_id


# Let connection errors propagate: previously they were printed and swallowed, and the
# next line then failed with a NameError on `conn`.
conn = sqlite3.connect(zdb_path)
try:
    cur = conn.cursor()
    # get recording start stop
    recording_start = _read_int_property(cur, 'RecordingStart')
    recording_stop = _read_int_property(cur, 'RecordingStop')
    length_ticks = recording_stop - recording_start  # 100-ns units
    length_s = length_ticks * 1e-7  # s
    hh, remainder = divmod(length_s, 3600)
    mm, ss = divmod(remainder, 60)
    print(hh, mm, ss, length_s)
    print(recording_start)
    print(recording_stop)

    _ensure_scoring_tables(cur)
    # The CSV header is the DataFrame column label 0, read back as the string '0'.
    keyid = _replace_scoring(cur, y_pred['0'], recording_start, offset, rename_dict)
    conn.commit()
finally:
    # Always release the file; uncommitted inserts/deletes are rolled back on error.
    conn.close()