In [None]:
import json
from pathlib import Path

import h5py
import torch
from torch.utils.data import Dataset
import pandas as pd

# Absolute path of the parent of the current working directory.
# NOTE(review): assumes the notebook is run from a folder one level below the
# project root — confirm before running from elsewhere.
project_dir = Path("../").resolve()

class WSIDataset(Dataset):
    """Map-style dataset that reads per-WSI embeddings from an HDF5 file.

    Items are addressed by position: the identifier and label at that
    position come from ``wsi_ids`` and ``labels``, and the embeddings are
    read from ``hdf5_file['embeddings'][wsi_id]``.

    The HDF5 file is opened and closed again on every item access; no file
    handle is stored on the instance.
    """

    def __init__(self, hdf5_filepath, wsi_ids, labels):
        """Remember where the embeddings live and which WSIs to serve.

        Args:
            hdf5_filepath: Location handed to ``h5py.File`` on each access.
            wsi_ids: Indexable collection of keys looked up under the
                file's ``'embeddings'`` entry; its length is the dataset
                length.
            labels: Indexable collection read at the same positions as
                ``wsi_ids``.
        """
        self.hdf5_filepath, self.wsi_ids, self.labels = (
            hdf5_filepath,
            wsi_ids,
            labels,
        )

    def __len__(self):
        """Return the number of WSI identifiers held by the dataset."""
        return len(self.wsi_ids)

    def __getitem__(self, idx):
        """Return ``(wsi_id, embeddings, label)`` for position ``idx``.

        ``embeddings`` is a ``torch.Tensor`` built from the data stored
        under ``'embeddings'`` / ``wsi_id`` in the HDF5 file.
        """
        slide_key = self.wsi_ids[idx]
        target = self.labels[idx]

        # Open read-only just for this lookup; the context manager closes
        # the file as soon as the data has been read and converted.
        with h5py.File(self.hdf5_filepath, 'r') as store:
            raw_values = store['embeddings'][slide_key][:]
            features = torch.tensor(raw_values)

        return slide_key, features, target

In [None]:
def load_metadata(base_dir=None):
    """Load WSI metadata and fold information for cross-validation.

    Args:
        base_dir: Optional directory that contains ``data/interim``. When
            omitted (the default), the module-level ``project_dir`` is
            used, which matches the previous zero-argument behaviour.

    Returns:
        tuple: ``(wsi_metadata, fold_df)`` where ``wsi_metadata`` is the
        parsed content of ``data/interim/tcga_wsi_data.json`` and
        ``fold_df`` is the ``pandas.DataFrame`` read from
        ``data/interim/tcga_folds.csv``.

    Raises:
        FileNotFoundError: If either input file does not exist.
        json.JSONDecodeError: If the metadata file is not valid JSON.
    """
    root = project_dir if base_dir is None else Path(base_dir)
    interim_dir = root / "data/interim"

    fold_df = pd.read_csv(interim_dir / "tcga_folds.csv")

    # Decode explicitly as UTF-8: without an encoding argument ``open`` uses
    # the platform locale, which can mis-decode non-ASCII metadata.
    with open(interim_dir / "tcga_wsi_data.json", encoding="utf-8") as f:
        wsi_metadata = json.load(f)

    return wsi_metadata, fold_df



In [None]:
# Load the WSI metadata and the fold table via load_metadata().
wsi_metadata, fold_df = load_metadata()

In [None]:
# Show the loaded metadata as this cell's output.
# NOTE(review): this renders the entire object; consider previewing only a
# subset if the metadata is large.
wsi_metadata