In [7]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


#### Preprocessing Details


Before running this notebook, please preprocess your PSG files using the scripts provided in `sleepfm/preprocessing`. Note that PSG recordings may contain different sets of channels across datasets. The predefined channel–modality mappings used in this project are specified in `sleepfm/configs/channel_groups.json`.

Although we have attempted to make this mapping as comprehensive as possible, we strongly recommend reviewing the channels present in your specific PSG data. In consultation with domain experts, you should group any additional or dataset-specific channels into the appropriate modality categories and update `channel_groups.json` accordingly. This step is critical to ensure that all channels are correctly aligned with their intended modalities during preprocessing and downstream modeling.

In [110]:
import torch
from torch import nn
import numpy as np
import os
import tqdm
import random
import sys
sys.path.append("..")
sys.path.append("../sleepfm")
import pandas as pd
from models.dataset import SetTransformerDataset, collate_fn
from models.models import SetTransformer, SleepEventLSTMClassifier, DiagnosisFinetuneFullLSTMCOXPHWithDemo
import h5py
from utils import load_config, load_data, save_data, count_parameters
from torch.utils.data import Dataset, DataLoader

#### Part 1: Generating embeddings from SleepFM pretrained model

Note: This is just a demo notebook which shows generating embedding for 1 demno PSG. To see full script, please check `sleepfm/pipeline/generate_embeddings.py`. 

In [9]:
model_path = "../sleepfm/checkpoints/model_base"
channel_groups_path = "../sleepfm/configs/channel_groups.json"
config_path = os.path.join(model_path, "config.json")

config = load_config(config_path)
channel_groups = load_data(channel_groups_path)

In [11]:
modality_types = config["modality_types"]
in_channels = config["in_channels"]
patch_size = config["patch_size"]
embed_dim = config["embed_dim"]
num_heads = config["num_heads"]
num_layers = config["num_layers"]
pooling_head = config["pooling_head"]
dropout = 0.0

In [12]:
model_class = getattr(sys.modules[__name__], config['model'])
model = model_class(in_channels, patch_size, embed_dim, num_heads, num_layers, pooling_head=pooling_head, dropout=dropout)

device = torch.device("cuda")
if device.type == "cuda":
    model = torch.nn.DataParallel(model)

model.to(device)
total_layers, total_params = count_parameters(model)
print(f'Trainable parameters: {total_params / 1e6:.2f} million')
print(f'Number of layers: {total_layers}')

Trainable parameters: 4.44 million
Number of layers: 93


In [13]:
checkpoint = torch.load(os.path.join(model_path, "best.pt"))
model.load_state_dict(checkpoint["state_dict"])
model.eval()

DataParallel(
  (module): SetTransformer(
    (patch_embedding): Tokenizer(
      (tokenizer): Sequential(
        (0): Conv1d(1, 4, kernel_size=(5,), stride=(2,), padding=(2,))
        (1): BatchNorm1d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ELU(alpha=1.0)
        (3): LayerNorm((4, 320), eps=1e-05, elementwise_affine=True)
        (4): Conv1d(4, 8, kernel_size=(5,), stride=(2,), padding=(2,))
        (5): BatchNorm1d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (6): ELU(alpha=1.0)
        (7): LayerNorm((8, 160), eps=1e-05, elementwise_affine=True)
        (8): Conv1d(8, 16, kernel_size=(5,), stride=(2,), padding=(2,))
        (9): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (10): ELU(alpha=1.0)
        (11): LayerNorm((16, 80), eps=1e-05, elementwise_affine=True)
        (12): Conv1d(16, 32, kernel_size=(5,), stride=(2,), padding=(2,))
        (13): BatchNorm1d(32, eps=1

In [14]:
hdf5_paths = ["demo_psg.hdf5"]
dataset = SetTransformerDataset(config, channel_groups, hdf5_paths=hdf5_paths, split="test")

dataloader = torch.utils.data.DataLoader(dataset, 
                                            batch_size=16, 
                                            num_workers=1, 
                                            shuffle=False, 
                                            collate_fn=collate_fn)

Indexing files: 100%|██████████| 1/1 [00:00<00:00,  4.44it/s]


In [15]:
output = "demo_emb"
output_5min_agg = f"demo_5min_agg_emb"
os.makedirs(output, exist_ok=True)
os.makedirs(output_5min_agg, exist_ok=True)

In [71]:
with torch.no_grad():
    with tqdm.tqdm(total=len(dataloader)) as pbar:
        for batch in dataloader:
            batch_data, mask_list, file_paths, dset_names_list, chunk_starts = batch
            (bas, resp, ekg, emg) = batch_data
            (mask_bas, mask_resp, mask_ekg, mask_emg) = mask_list

            bas = bas.to(device, dtype=torch.float)
            resp = resp.to(device, dtype=torch.float)
            ekg = ekg.to(device, dtype=torch.float)
            emg = emg.to(device, dtype=torch.float)

            mask_bas = mask_bas.to(device, dtype=torch.bool)
            mask_resp = mask_resp.to(device, dtype=torch.bool)
            mask_ekg = mask_ekg.to(device, dtype=torch.bool)
            mask_emg = mask_emg.to(device, dtype=torch.bool)

            embeddings = [
                model(bas, mask_bas),
                model(resp, mask_resp),
                model(ekg, mask_ekg),
                model(emg, mask_emg),
            ]

            # Model gives two kinds of embeddings. Granular 5 second-level embeddings and aggregated 5 minute-level embeddings. We save both of them below. 

            embeddings_new = [e[0].unsqueeze(1) for e in embeddings]

            for i in range(len(file_paths)):
                file_path = file_paths[i]
                chunk_start = chunk_starts[i]
                subject_id = os.path.basename(file_path).split('.')[0]
                output_path = os.path.join(output_5min_agg, f"{subject_id}.hdf5")

                with h5py.File(output_path, 'a') as hdf5_file:
                    for modality_idx, modality_type in enumerate(config["modality_types"]):
                        if modality_type in hdf5_file:
                            dset = hdf5_file[modality_type]
                            chunk_start_correct = chunk_start // (embed_dim * 5 * 60)
                            chunk_end = chunk_start_correct + embeddings_new[modality_idx][i].shape[0]
                            if dset.shape[0] < chunk_end:
                                dset.resize((chunk_end,) + embeddings_new[modality_idx][i].shape[1:])
                            dset[chunk_start_correct:chunk_end] = embeddings_new[modality_idx][i].cpu().numpy()
                        else:
                            hdf5_file.create_dataset(modality_type, data=embeddings_new[modality_idx][i].cpu().numpy(), chunks=(embed_dim,) + embeddings_new[modality_idx][i].shape[1:], maxshape=(None,) + embeddings_new[modality_idx][i].shape[1:])

            embeddings_new = [e[1] for e in embeddings]

            for i in range(len(file_paths)):
                file_path = file_paths[i]
                chunk_start = chunk_starts[i]
                subject_id = os.path.basename(file_path).split('.')[0]
                output_path = os.path.join(output, f"{subject_id}.hdf5")

                with h5py.File(output_path, 'a') as hdf5_file:
                    for modality_idx, modality_type in enumerate(config["modality_types"]):
                        if modality_type in hdf5_file:
                            dset = hdf5_file[modality_type]
                            chunk_start_correct = chunk_start // (embed_dim * 5)
                            chunk_end = chunk_start_correct + embeddings_new[modality_idx][i].shape[0]
                            if dset.shape[0] < chunk_end:
                                dset.resize((chunk_end,) + embeddings_new[modality_idx][i].shape[1:])
                            dset[chunk_start_correct:chunk_end] = embeddings_new[modality_idx][i].cpu().numpy()
                        else:
                            hdf5_file.create_dataset(modality_type, data=embeddings_new[modality_idx][i].cpu().numpy(), chunks=(embed_dim,) + embeddings_new[modality_idx][i].shape[1:], maxshape=(None,) + embeddings_new[modality_idx][i].shape[1:])
            pbar.update()

100%|██████████| 8/8 [00:17<00:00,  2.20s/it]


#### Sleep Staging

Note that below, we are using our finetuned sleep staging model. It is always a good idea to finetune our model on your specific data, even if you only have a handful of sample, so that the model can adapt to your specific data distribution. Script to finetune your sleep staging model head is given in `sleepfm/pipeline/finetune_sleep_staging.py`. 

In [16]:
sleep_staging_model_path = "../sleepfm/checkpoints/model_sleep_staging"
sleep_staging_config = load_data(os.path.join(sleep_staging_model_path, "config.json"))

sleep_staging_model_params = sleep_staging_config['model_params']
sleep_staging_model_class = getattr(sys.modules[__name__], sleep_staging_config['model'])

sleep_staging_model = sleep_staging_model_class(**sleep_staging_model_params).to(device)
sleep_staging_model_name = type(sleep_staging_model).__name__

In [17]:
sleep_staging_model = nn.DataParallel(sleep_staging_model)
print(f"Using {torch.cuda.device_count()} GPUs")

Using 1 GPUs


In [18]:
print(f"Model initialized: {sleep_staging_model_name}")
total_layers, total_params = count_parameters(sleep_staging_model)
print(f'Trainable parameters: {total_params / 1e6:.2f} million')
print(f'Number of layers: {total_layers}')

Model initialized: SleepEventLSTMClassifier
Trainable parameters: 1.19 million
Number of layers: 20


In [19]:
sleep_staging_checkpoint_path = os.path.join(sleep_staging_model_path, "best.pth")
sleep_staging_checkpoint = torch.load(sleep_staging_checkpoint_path)
sleep_staging_model.load_state_dict(sleep_staging_checkpoint)

<All keys matched successfully>

Below are some helper functions for loading data for sleep staging. You can find similar functions within `sleepfm/models/dataset.py`. You may need to modify it slightly based on your usecase. 

In [24]:
class SleepEventClassificationDataset(Dataset):
    def __init__(self,
                 config,
                 channel_groups,
                 hdf5_paths=None,
                 split="train"):

        self.config = config
        self.max_channels = self.config["max_channels"]
        self.context = int(self.config["context"])
        self.channel_like = self.config["channel_like"]

        # ---- Resolve HDF5 paths (simple rule) ----
        # If hdf5_paths is provided, use it. Otherwise, load from config["split_path"].
        if hdf5_paths:
            hdf5_paths = [p for p in hdf5_paths if os.path.exists(p)]
        else:
            data_path = config["data_path"]
            split_paths = load_data(config["split_path"])[split]
            hdf5_paths = []
            for rel_path in split_paths:
                abs_path = os.path.join(data_path, rel_path)
                if os.path.exists(abs_path):
                    hdf5_paths.append(abs_path)

        # Optional truncation
        if config.get("max_files"):
            hdf5_paths = hdf5_paths[:config["max_files"]]

        # ---- Build index map ----
        if self.context == -1:
            self.index_map = [(path, -1) for path in hdf5_paths]
        else:
            self.index_map = []
            loop = tqdm(hdf5_paths, total=len(hdf5_paths), desc=f"Indexing {split} data")
            for hdf5_file_path in loop:
                try:
                    with h5py.File(hdf5_file_path, "r") as file:
                        dset_names = list(file.keys())
                        if len(dset_names) == 0:
                            continue
                        # Use the first dataset to infer length (matches your original logic)
                        dset0 = dset_names[0]
                        dataset_length = file[dset0].shape[0]
                        for i in range(0, dataset_length, self.context):
                            self.index_map.append((hdf5_file_path, i))
                except OSError:
                    # Corrupt/unreadable file; skip
                    continue

        print(f"Number of files in {split} set: {len(hdf5_paths)}")
        print(f"Number of segments to be processed in {split} set: {len(self.index_map)}")

        self.total_len = len(self.index_map)
        self.max_seq_len = config["model_params"]["max_seq_length"]

        if self.total_len == 0:
            raise ValueError(f"No valid samples found for split='{split}'. Check paths/config.")

    def __len__(self):
        return self.total_len

    def get_index_map(self):
        return self.index_map

    def __getitem__(self, idx):
        hdf5_path, start_index = self.index_map[idx]

        x_data = []
        try:
            with h5py.File(hdf5_path, "r") as hf:
                dset_names = list(hf.keys())
                for dataset_name in dset_names:
                    if dataset_name in self.channel_like:
                        if self.context == -1:
                            x_data.append(hf[dataset_name][:])
                        else:
                            x_data_in = hf[dataset_name][start_index:start_index + self.context]
                            x_data.append(x_data_in)
        except OSError:
            # If file can't be read, skip to next example
            return self.__getitem__((idx + 1) % self.total_len)

        if not x_data:
            # Skip this data point if x_data is empty
            return self.__getitem__((idx + 1) % self.total_len)

        # Convert list -> array -> tensor (keeps your original behavior)
        x_data = np.array(x_data)
        x_data = torch.tensor(x_data, dtype=torch.float32)

        return x_data, self.max_channels, self.max_seq_len, hdf5_path


def sleep_event_finetune_full_collate_fn(batch):
    x_data, max_channels_list, max_seq_len_list, hdf5_path_list = zip(*batch)

    num_channels = max(max_channels_list)

    max_seq_len_temp = max([item.size(1) for item in x_data])

    # Determine the max sequence length for padding
    if max_seq_len_list[0] is None:
        max_seq_len = max_seq_len_temp
    else:
        max_seq_len = min(max_seq_len_temp, max_seq_len_list[0])

    padded_x_data = []
    padded_mask = []

    for x_item in x_data:
        # Get the shape of x_item
        c, s, e = x_item.size()
        c = min(c, num_channels)
        s = min(s, max_seq_len)  # Ensure the sequence length doesn't exceed max_seq_len

        # Create a padded tensor and a mask tensor for x_data
        padded_x_item = torch.zeros((num_channels, max_seq_len, e))
        mask = torch.ones((num_channels, max_seq_len))

        # Copy the actual data to the padded tensor and set the mask for real data
        padded_x_item[:c, :s, :e] = x_item[:c, :s, :e]
        mask[:c, :s] = 0  # 0 for real data, 1 for padding

        padded_x_data.append(padded_x_item)
        padded_mask.append(mask)

    # Stack all tensors into a batch
    x_data = torch.stack(padded_x_data)
    padded_mask = torch.stack(padded_mask)
    
    return x_data, padded_mask, hdf5_path_list

In [25]:
hdf5_paths = ["demo_emb/demo_psg.hdf5"]
test_dataset = SleepEventClassificationDataset(sleep_staging_config, channel_groups, split="test", hdf5_paths=hdf5_paths)

Number of files in test set: 1
Number of segments to be processed in test set: 1


In [26]:
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=1, collate_fn=sleep_event_finetune_full_collate_fn)

In [39]:
# Validation loop at the end of each epoch
model.eval()
val_loss = 0.0
all_targets = []
all_logits = []
all_outputs = []
all_masks = []
all_paths = []

count = 0
with torch.no_grad():
    for (x_data, padded_matrix, hdf5_path_list) in tqdm.tqdm(test_loader, desc="Evaluating"):
        x_data, padded_matrix, hdf5_path_list = x_data.to(device), padded_matrix.to(device), list(hdf5_path_list)
        outputs, mask = sleep_staging_model(x_data, padded_matrix)
        all_outputs.append(torch.softmax(outputs, dim=-1).cpu().numpy())
        all_logits.append(outputs.cpu().numpy())
        all_masks.append(mask.cpu().numpy())
        all_paths.append(hdf5_path_list)


save_path = "demo_sleep_staging"
os.makedirs(save_path, exist_ok=True)

outputs_path = os.path.join(save_path, "all_outputs.pickle")
logits_path = os.path.join(save_path, "all_logits.pickle")
mask_path = os.path.join(save_path, "all_masks.pickle")
file_paths = os.path.join(save_path, "all_paths.pickle")

save_data(all_outputs, outputs_path)
save_data(all_logits, logits_path)
save_data(all_masks, mask_path)
save_data(all_paths, file_paths)

Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.64it/s]


In [43]:
all_outputs[0].shape

(1, 7560, 5)

Now you have the logits and outputs, that you can then use to do sleep staging. 

#### Disease Prediction

In [70]:
disease_model_path = "../sleepfm/checkpoints/model_diagnosis"
















config = load_data(os.path.join(disease_model_path, "config.json"))

In [71]:
config["model_params"]["dropout"] = 0.0
model_params = config['model_params']
model_class = getattr(sys.modules[__name__], config['model'])
model = model_class(**model_params).to(device)
model_name = type(model).__name__

In [72]:
model = nn.DataParallel(model)
print(f"Model initialized: {model_name}")
total_layers, total_params = count_parameters(model)
print(f'Trainable parameters: {total_params / 1e6:.2f} million')
print(f'Number of layers: {total_layers}')

Model initialized: DiagnosisFinetuneFullLSTMCOXPHWithDemo
Trainable parameters: 0.91 million
Number of layers: 15


In [73]:
checkpoint_path = os.path.join(disease_model_path, "best.pth")

In [74]:
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint)

<All keys matched successfully>

In [105]:
class DiagnosisFinetuneFullCOXPHWithDemoDataset(Dataset):
    def __init__(self, 
                 config,
                 channel_groups,
                 hdf5_paths=None,
                 demo_labels_path=None, 
                 split="train"):

        self.config = config
        self.channel_groups = channel_groups
        self.max_channels = self.config["max_channels"]

        # --- Load demographic features ---
        if not demo_labels_path:
            demo_labels_path = config["demo_labels_path"]

        demo_labels_df = pd.read_csv(demo_labels_path)
        demo_labels_df = demo_labels_df.set_index("Study ID")
        study_ids = set(demo_labels_df.index)
        print(study_ids)

        # --- Resolve HDF5 paths (explicit precedence) ---
        if hdf5_paths:
            # Use provided paths directly
            hdf5_paths = [f for f in hdf5_paths if os.path.exists(f)]
        else:
            # Load from split file
            split_paths = load_data(config["split_path"])[split]
            hdf5_paths = [f for f in split_paths if os.path.exists(f)]

        # Filter by available demo labels
        hdf5_paths = [
            f for f in hdf5_paths
            if os.path.basename(f).split(".")[0] in study_ids
        ]

        # Optional truncation
        if config.get("max_files"):
            hdf5_paths = hdf5_paths[:config["max_files"]]

        # --- Build labels dict (demo only) ---
        labels_dict = {}
        for study_id in tqdm.tqdm(study_ids, desc="Loading demo features"):
            labels_dict[study_id] = {
                "demo_feats": list(demo_labels_df.loc[study_id].values)
            }

        # --- Build index map ---
        self.index_map = [
            (path, labels_dict[os.path.basename(path).split(".")[0]])
            for path in hdf5_paths
        ]

        print(f"Number of files in {split} set: {len(hdf5_paths)}")
        print(f"Number of files to be processed in {split} set: {len(self.index_map)}")

        self.total_len = len(self.index_map)
        self.max_seq_len = config["model_params"]["max_seq_length"]

        if self.total_len == 0:
            raise ValueError(f"No valid HDF5 files found for split='{split}'.")

    def __len__(self):
        return self.total_len

    def __getitem__(self, idx):
        hdf5_path, label_dict = self.index_map[idx]

        demo_feats = label_dict["demo_feats"]

        x_data = []
        with h5py.File(hdf5_path, "r") as hf:
            dset_names = [
                dset_name for dset_name in hf.keys()
                if isinstance(hf[dset_name], h5py.Dataset)
                and dset_name in self.config["modality_types"]
            ]

            random.shuffle(dset_names)
            for dataset_name in dset_names:
                x_data.append(hf[dataset_name][:])

        if not x_data:
            # Skip empty sample
            return self.__getitem__((idx + 1) % self.total_len)

        # Convert to tensor
        x_data = torch.tensor(np.array(x_data), dtype=torch.float32)
        demo_feats = torch.tensor(demo_feats, dtype=torch.float32)

        return x_data, demo_feats, self.max_channels, self.max_seq_len, hdf5_path


def diagnosis_finetune_full_coxph_with_demo_collate_fn(batch):
    x_data, demo_feats, max_channels_list, max_seq_len_list, hdf5_path_list = zip(*batch)

    num_channels = max(max_channels_list)

    if max_seq_len_list[0] == None:
        max_seq_len = max([item.size(1) for item in x_data])
    else:
        max_seq_len = max_seq_len_list[0]

    padded_x_data = []
    padded_mask = []
    for item in x_data:
        c, s, e = item.size()
        c = min(c, num_channels)
        s = min(s, max_seq_len)  # Ensure the sequence length doesn't exceed max_seq_len

        # Create a padded tensor and a mask tensor
        padded_item = torch.zeros((num_channels, max_seq_len, e))
        mask = torch.ones((num_channels, max_seq_len))

        # Copy the actual data to the padded tensor and set the mask for real data
        padded_item[:c, :s, :e] = item[:c, :s, :e]
        mask[:c, :s] = 0  # 0 for real data, 1 for padding

        padded_x_data.append(padded_item)
        padded_mask.append(mask)

    # Stack all tensors into a batch
    x_data = torch.stack(padded_x_data)
    demo_feats = torch.stack(demo_feats)
    padded_mask = torch.stack(padded_mask)

    return x_data, demo_feats, padded_mask, hdf5_path_list

In [106]:
save_path = "demo_diagnosis"
os.makedirs(save_path, exist_ok=True)

In [107]:
hdf5_paths = ["demo_emb/demo_psg.hdf5"]
demo_labels_path = "demo_age_gender.csv"
test_dataset = DiagnosisFinetuneFullCOXPHWithDemoDataset(config, channel_groups, split="test", hdf5_paths=hdf5_paths, demo_labels_path=demo_labels_path)

{'demo_psg'}


Loading demo features: 100%|██████████| 1/1 [00:00<00:00, 7436.71it/s]

Number of files in test set: 1
Number of files to be processed in test set: 1





In [108]:
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=1, collate_fn=diagnosis_finetune_full_coxph_with_demo_collate_fn)

In [111]:
model.eval()
all_event_times = []
all_is_event = []
all_outputs = []
all_paths = []

with torch.no_grad():
    for item in tqdm.tqdm(test_loader, desc="Evaluating"):
        x_data, demo_feats, padded_matrix, hdf5_path_list = item
        x_data, demo_feats, padded_matrix, hdf5_path_list = x_data.to(device), demo_feats.to(device), padded_matrix.to(device), list(hdf5_path_list)
        outputs = model(x_data, padded_matrix, demo_feats)
    
        logits = outputs.cpu().numpy()
        all_outputs.append(logits)
        all_paths.append(hdf5_path_list)

all_outputs = np.concatenate(all_outputs, axis=0)
all_paths = np.concatenate(all_paths)

outputs_path = os.path.join(save_path, "all_outputs.pickle")
file_paths = os.path.join(save_path, "all_paths.pickle")

save_data(all_outputs, outputs_path)
save_data(all_paths, file_paths)

Evaluating:   0%|          | 0/1 [00:00<?, ?it/s]

  return torch._transformer_encoder_layer_fwd(
Evaluating: 100%|██████████| 1/1 [00:00<00:00,  2.34it/s]


In [112]:
all_outputs.shape

(1, 1065)

Above, you get the model outputs, which you can then use to look for specific disease diagnosis. Nope that the shape of the output above is 1065, meaning, this model gives logprobs for 1065 conditions. We provide information about each disease index and its corresponding phecode here `sleepfm/configs/label_mapping.csv`. You can map it as follows. 

In [136]:
labels_df = pd.read_csv("../sleepfm/configs/label_mapping.csv")

In [138]:
labels_df["output"] = all_outputs[0]

In [139]:
labels_df.head()

Unnamed: 0,label_idx,phecode,phenotype,output
0,0,8.0,Intestinal infection,0.857936
1,1,8.5,Bacterial enteritis,0.774784
2,2,8.6,Viral Enteritis,1.054985
3,3,38.0,Septicemia,1.575306
4,4,38.3,Bacteremia,1.663425
