In [1]:
import torch
import os
import torch
from torch import nn
from torch.nn.functional import relu
from mne.io import read_raw_edf
from tqdm import tqdm
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

class ResidualBlock(nn.Module):
    def __init__(self,in_feature_maps,out_feature_maps,n_features) -> None:
        super().__init__()
        self.c1 = nn.Conv1d(in_feature_maps,out_feature_maps,kernel_size=8,padding='same',bias=False)
        self.bn1 = nn.LayerNorm((out_feature_maps,n_features),elementwise_affine=False)

        self.c2 = nn.Conv1d(out_feature_maps,out_feature_maps,kernel_size=5,padding='same',bias=False)
        self.bn2 = nn.LayerNorm((out_feature_maps,n_features),elementwise_affine=False)

        self.c3 = nn.Conv1d(out_feature_maps,out_feature_maps,kernel_size=3,padding='same',bias=False)
        self.bn3 = nn.LayerNorm((out_feature_maps,n_features),elementwise_affine=False)

        self.c4 = nn.Conv1d(in_feature_maps,out_feature_maps,1,padding='same',bias=False)
        self.bn4 = nn.LayerNorm((out_feature_maps,n_features),elementwise_affine=False)

    def forward(self,x):
        identity = x
        x = self.c1(x)
        x = self.bn1(x)
        x = relu(x)

        x = self.c2(x)
        x = self.bn2(x)
        x = relu(x)

        x = self.c3(x)
        x = self.bn3(x)
        x = relu(x)

        identity = self.c4(identity)
        identity = self.bn4(identity)

        x = x+identity
        x = relu(x)
        
        return x
    
class Frodo(nn.Module):
    def __init__(self,n_features) -> None:
        super().__init__()
        self.n_features = n_features
        self.block1 = ResidualBlock(1,8,n_features)
        self.block2 = ResidualBlock(8,16,n_features)
        self.block3 = ResidualBlock(16,16,n_features)

        self.gap = nn.AvgPool1d(kernel_size=n_features)
        self.fc1 = nn.Linear(in_features=16,out_features=3)
    def forward(self,x,classification=True):
        x = x.view(-1,1,self.n_features)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.gap(x)
        if(classification):
            x = self.fc1(x.squeeze())
            return x
        else:
            return x.squeeze()
        
class Gandalf(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.encoder = Frodo(n_features=5000)
        self.lstm = nn.LSTM(16,32,bidirectional=True)
        self.fc1 = nn.Linear(64,3)
    def forward(self,x_2d,classification=True):
        x_2d = x_2d.view(-1,9,1,5000)
        x = []
        for t in range(x_2d.size(1)):
            xi = self.encoder(x_2d[:,t,:,:],classification=False)
            x.append(xi.unsqueeze(0))
        x = torch.cat(x)
        out,_ = self.lstm(x)
        if(classification):
            x = self.fc1(out[-1])
        else:
            x = out[-1]
        return x
    
class EEGDataset(torch.utils.data.Dataset):
    def __init__(self,X):
        self.len = len(X)
        self.X = torch.cat([torch.zeros(4,5000),X,torch.zeros(4,5000)])

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        return self.X[idx:idx+9].flatten()

In [12]:
def score_recording(data_dir,id,model_path,eeg_ch_name,device):
    if not os.path.exists(model_path):
        raise FileExistsError(model_path)

    edf_filepath = f'{data_dir}/{id}.edf'

    model = Gandalf()
    model.load_state_dict(torch.load(model_path,map_location='cpu'))

    raw = read_raw_edf(input_fname=edf_filepath)
    data = raw.get_data(picks=eeg_ch_name)
    eeg = torch.from_numpy(data[0]).float()
    eeg = eeg.view(-1,5000)

    model.eval()
    model.to(device)

    with torch.no_grad():
        dataloader = DataLoader(EEGDataset(eeg),batch_size=32)
        logits = torch.cat([model(Xi.to(device)).cpu() for Xi in tqdm(dataloader)])
        y_pred = logits.softmax(dim=1).argmax(axis=1)

    pd.DataFrame(y_pred,columns=['y_pred']).to_csv(f"{edf_filepath.replace('.edf','_y_pred')}.csv",index=False)
    pd.DataFrame(logits).to_csv(f"{edf_filepath.replace('.edf','_logits')}.csv",index=False)

In [15]:
device = 'mps'
model_path = 'aurora-sleep-staging/gandalf.pt'
data_dir = f'data'

ids  = [file.replace('.edf','') for file in os.listdir(data_dir) if file.endswith('.edf')]
edfs = [file for file in os.listdir(data_dir) if file.endswith('.edf')]
zdbs = [file for file in os.listdir(data_dir) if file.endswith('.zdb')]

print(data_dir,ids)
assert len(ids) == len(edfs)
assert len(ids) == len(zdbs)

id_to_zdb_filename = {}
for id in ids:
    corresponding_zdb_filename_for_id = [zdb for zdb in zdbs if id in zdb]
    assert len(corresponding_zdb_filename_for_id) <= 1, f"Multiple zdb files found for id {id}: {corresponding_zdb_filename_for_id}"
    id_to_zdb_filename[id] = corresponding_zdb_filename_for_id[0]

for id in ids:
    zdb_filename = id_to_zdb_filename[id]
    edf_filename = f'{id}.edf'
    print(id)
    eeg_ch_name = "EEG 1"
    score_recording(data_dir=data_dir,id=id,model_path=model_path,eeg_ch_name=eeg_ch_name,device=device)

data ['EKyn Sleep and Sleep Dep 24-Jun PD35.24-Jun-A1.20240806200742']
EKyn Sleep and Sleep Dep 24-Jun PD35.24-Jun-A1.20240806200742
Extracting EDF parameters from /Users/andrew/neuroscore/data/EKyn Sleep and Sleep Dep 24-Jun PD35.24-Jun-A1.20240806200742.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...


  return F.conv1d(input, weight, bias, self.stride,
100%|██████████| 810/810 [00:24<00:00, 33.42it/s]


In [64]:
device = 'cuda'
model_path = 'gandalf.pt'
latex = """"""
latex_figs = """
\\documentclass{standalone}
\\usepackage{pgf}
\\standaloneenv{pgfpicture}
\\begin{document}
"""
fig_index = 1

for subdir in os.listdir('data'):
    data_dir = f'data/{subdir}'
    ids  = [file.replace('.edf','') for file in os.listdir(data_dir) if file.endswith('.edf')]
    edfs = [file for file in os.listdir(data_dir) if file.endswith('.edf')]
    zdbs = [file for file in os.listdir(data_dir) if file.endswith('.zdb')]
    print(data_dir,ids)
    if len(ids) != len(edfs) or len(ids) != len(zdbs) or len(edfs) != len(zdbs):
        print('big problem')
    id_to_zdb_filename = {}
    for id in ids:
        corresponding_zdb_filename_for_id = [zdb for zdb in zdbs if id in zdb]
        if len(corresponding_zdb_filename_for_id) != 1:
            print('big problem')
        id_to_zdb_filename[id] = corresponding_zdb_filename_for_id[0]
    for id in ids:
        zdb_filename = id_to_zdb_filename[id]
        edf_filename = f'{id}.edf'
        print(id)
        if id == "24-Jun-B1":
            eeg_ch_name = 'EEG 2'
        else:
            eeg_ch_name = 'EEG 1'

        y_pred = torch.from_numpy(pd.read_csv(f'{data_dir}/{id}_y_pred.csv').values.flatten())

        fig,ax = plt.subplots(ncols=2,nrows=2,figsize=(9,9))
        sns.barplot(x=['P','S','W'],y=y_pred.bincount()/len(y_pred),ax=ax[0,0])
        ax[0,0].set_xlabel('Sleep Stage')
        ax[0,0].set_ylabel('Proportion')
        ax[0,0].set_title('from all epochs')

        zdb_filename = id_to_zdb_filename[id]
        edf_filename = f'{id}.edf'
        raw = read_raw_edf(input_fname=f'{data_dir}/{id}.edf')
        data = raw.get_data(picks=eeg_ch_name)
        eeg = torch.from_numpy(data[0]).float()
        eeg = eeg.view(-1,5000)

        sns.histplot(eeg.std(dim=1),label=eeg_ch_name,ax=ax[0,1])
        ax[0,1].set_xlim([0,.0002])
        ax[0,1].set_title('from all epochs')

        # Remove the subplots in the second row
        fig.delaxes(ax[1, 0])
        fig.delaxes(ax[1, 1])
        # Add a new subplot that spans the second row
        ax_combined = fig.add_subplot(2, 1, 2)  # 2 rows, 1 column, subplot 2

        logits = torch.from_numpy(pd.read_csv(f'{data_dir}/{id}_logits.csv').values).float()[:1000]
        softmaxed_logits = logits.softmax(dim=1)
        stacked_data = softmaxed_logits.T
        ax_combined.plot(softmaxed_logits.max(dim=1).values,linewidth=.3,color='black',label='Network Confidence')
        ax_combined.stackplot(range(stacked_data.shape[1]),  # x-axis should correspond to each sample or step
                    stacked_data[0], 
                    stacked_data[1], 
                    stacked_data[2],labels=['Paradoxical', 'Slow Wave', 'Wakefulness'])
        ax_combined.set_title('from first 1000 epochs')
        ax_combined.margins(x=0,y=0)
        ax_combined.legend()
        plt.savefig(f'{data_dir}/{id}.pgf',bbox_inches='tight')
        plt.close()

        latex += f"""
        \\subsection*{{ {data_dir} {id} }}
        \\begin{{center}}
        \\begin{{adjustbox}}{{width=1\\textwidth}}
        \includegraphics[page={fig_index}]{{figs.pdf}}
        \\end{{adjustbox}}
        \\end{{center}}
        \\large\\textbf{{Length (epochs): {len(eeg):.3f}}}\\\\
        \\textbf{{Length (hours): {len(eeg)*(10/3600):.3f}}}\\\\
        \\textbf{{Average Network Confidence: {softmaxed_logits.max(dim=1).values.mean():.3f}}}\\\\
        \\textbf{{Paradoxical Proportion: {(y_pred.bincount()/len(y_pred))[0]:.3f}}}\\\\
        \\textbf{{Slow-Wave Proportion: {(y_pred.bincount()/len(y_pred))[1]:.3f}}}\\\\
        \\textbf{{Wakefulness Proportion: {(y_pred.bincount()/len(y_pred))[2]:.3f}}}\\\\
        """
        latex_figs += f"""\\input{{../{data_dir}/{id}.pgf}}
"""
        fig_index += 1
# Specify the filename where you want to save the LaTeX content
latex_figs+="""\\end{document}"""

# Open the file in write mode
with open("report/figs.tex", 'w') as file:
    # Write the LaTeX docstring to the file
    file.write(latex_figs)
# Open the file in write mode
with open("report/insert.tex", 'w') as file:
    # Write the LaTeX docstring to the file
    file.write(latex)

data/24-Jun PD49 ['24-Jun-D2', '24-Jun-A1', '24-Jun-D1', '24-Jun-A2', '24-Jun-B2', '24-Jun-B1']
24-Jun-D2
Extracting EDF parameters from /home/andrew/aurora-sleep-staging/data/24-Jun PD49/24-Jun-D2.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
24-Jun-A1
Extracting EDF parameters from /home/andrew/aurora-sleep-staging/data/24-Jun PD49/24-Jun-A1.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
24-Jun-D1
Extracting EDF parameters from /home/andrew/aurora-sleep-staging/data/24-Jun PD49/24-Jun-D1.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
24-Jun-A2
Extracting EDF parameters from /home/andrew/aurora-sleep-staging/data/24-Jun PD49/24-Jun-A2.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
24-Jun-B2
Extracting EDF parameters from /home/andrew/aurora-sleep-staging/data/24-Jun PD49/24-Jun-B2.edf...
EDF file detected
Setting c

In [65]:
latex

'\n        \\subsection*{ data/24-Jun PD49 24-Jun-D2 }\n        \\begin{center}\n        \\begin{adjustbox}{width=1\\textwidth}\n        \\includegraphics[page=1]{figs.pdf}\n        \\end{adjustbox}\n        \\end{center}\n        \\large\\textbf{Length (epochs): 25547.000}\\\\\n        \\textbf{Length (hours): 70.964}\\\\\n        \\textbf{Average Network Confidence: 0.935}\\\\\n        \\textbf{Paradoxical Proportion: 0.073}\\\\\n        \\textbf{Slow-Wave Proportion: 0.358}\\\\\n        \\textbf{Wakefulness Proportion: 0.569}\\\\\n        \n        \\subsection*{ data/24-Jun PD49 24-Jun-A1 }\n        \\begin{center}\n        \\begin{adjustbox}{width=1\\textwidth}\n        \\includegraphics[page=2]{figs.pdf}\n        \\end{adjustbox}\n        \\end{center}\n        \\large\\textbf{Length (epochs): 15721.000}\\\\\n        \\textbf{Length (hours): 43.669}\\\\\n        \\textbf{Average Network Confidence: 0.882}\\\\\n        \\textbf{Paradoxical Proportion: 0.022}\\\\\n        \\textbf{