In [None]:
import json
from pathlib import Path

import h5py
import torch
from torch.utils.data import Dataset
import pandas as pd
from torch.utils.data import DataLoader

# Project root: the parent of the current working directory, as an absolute path.
# NOTE(review): this presumes the notebook is launched from a direct
# subdirectory of the project — confirm before relying on it elsewhere.
project_dir = Path("../").resolve()

In [None]:
class WSIDataset(Dataset):
    """Map-style dataset that yields one whole-slide image (WSI) entry per item.

    Embeddings are read lazily: every item access opens the HDF5 file, loads
    the full ``embeddings/<wsi_id>`` entry into memory and closes the file.

    Args:
        hdf5_filepath: Path of the HDF5 file holding the embeddings.
        wsi_ids: Indexable sequence of slide identifiers; its length is the
            length of the dataset.
        labels: Indexable sequence of labels, looked up with the same index
            as ``wsi_ids``.
    """

    def __init__(self, hdf5_filepath, wsi_ids, labels):
        self.hdf5_filepath, self.wsi_ids, self.labels = (
            hdf5_filepath,
            wsi_ids,
            labels,
        )

    def __len__(self):
        """Return the number of slide identifiers."""
        return len(self.wsi_ids)

    def __getitem__(self, idx):
        """Return ``(wsi_id, embeddings, label)`` for position ``idx``.

        ``embeddings`` is a ``torch.Tensor`` built from the stored entry.
        """
        slide_id, target = self.wsi_ids[idx], self.labels[idx]

        # The file handle lives only for this read; nothing stays open
        # between items.
        with h5py.File(self.hdf5_filepath, 'r') as store:
            stored_entry = store['embeddings'][slide_id]
            features = torch.tensor(stored_entry[:])

        return slide_id, features, target

In [None]:
def load_metadata(base_dir=None):
    """Load WSI metadata and fold information for cross-validation.

    Args:
        base_dir: Directory that contains ``data/interim``. When omitted,
            the notebook-level ``project_dir`` is used.

    Returns:
        tuple: ``(wsi_metadata, fold_df)`` where ``wsi_metadata`` is the
        parsed content of ``tcga_wsi_data.json`` and ``fold_df`` is the
        DataFrame read from ``tcga_folds.csv``.
    """
    root = project_dir if base_dir is None else Path(base_dir)
    interim_dir = root / "data" / "interim"

    fold_df = pd.read_csv(interim_dir / "tcga_folds.csv")
    # JSON text is UTF-8 by specification; do not depend on the platform's
    # locale encoding when decoding it.
    with open(interim_dir / "tcga_wsi_data.json", encoding="utf-8") as f:
        wsi_metadata = json.load(f)
    return wsi_metadata, fold_df


In [None]:
# Read the slide metadata records and the cross-validation fold table.
wsi_metadata, fold_df = load_metadata()

In [None]:
# Collect the slide identifier of every metadata record, keeping file order.
wsi_ids = [record["wsi_id"] for record in wsi_metadata]

In [None]:
# Smoke-test dataset: the first ten slide ids, each paired with a placeholder
# label of 1.
# NOTE(review): the embeddings path is absolute and machine-specific —
# consider deriving it from project_dir once the layout is confirmed.
dataset = WSIDataset(
    "/home/valentin/workspaces/histolung/data/embeddings/uni_embeddings.h5",
    wsi_ids[:10],
    [1] * 10,
)

In [None]:
def collate_fn_ragged(batch):
    """Collate ``(wsi_id, embeddings, label)`` samples without stacking embeddings.

    Args:
        batch: Iterable of 3-tuples as produced by the dataset.

    Returns:
        tuple: ``(ids, embeddings, labels)`` where ``ids`` and ``embeddings``
        are plain lists (one entry per sample, left unstacked) and ``labels``
        is a single tensor built from all sample labels.
    """
    # Transpose the list of samples into three per-field lists.
    ids, bags, targets = map(list, zip(*batch))
    return ids, bags, torch.tensor(targets)

In [None]:
# Five samples per batch; the custom collate function is used in place of
# the default one.
data_loader = DataLoader(
    dataset,
    batch_size=5,
    collate_fn=collate_fn_ragged,
)

In [None]:
# Print the embedding shape of every slide, batch by batch.
# Loop variables carry batch-specific names so the notebook-level `wsi_ids`
# list is not overwritten by the ids of the last batch.
for batch_wsi_ids, batch_embeddings, _batch_labels in data_loader:
    for wsi_id, embedding in zip(batch_wsi_ids, batch_embeddings):
        print(f"{wsi_id}'s embedding has a shape of: {embedding.shape}")

In [None]:

def inspect_hdf5(file_path):
    """Print every group and dataset found in an HDF5 file.

    For datasets, the compression filter and its options are printed as well.

    Args:
        file_path: Path of the HDF5 file to open read-only.
    """

    def _describe(name, obj):
        # Returns None on every path so the traversal keeps going.
        if isinstance(obj, h5py.Dataset):
            print(f"Dataset: {name}")
            print(f"  Compression: {obj.compression}")
            print(f"  Compression options: {obj.compression_opts}")
            return
        if isinstance(obj, h5py.Group):
            print(f"Group: {name}")

    with h5py.File(file_path, "r") as h5_file:
        # Walk the whole hierarchy, calling the visitor on each item.
        h5_file.visititems(_describe)

# Example usage: list the groups and datasets of the embeddings file.
# NOTE(review): absolute, machine-specific path — consider deriving it from
# project_dir once the layout is confirmed.
inspect_hdf5("/home/valentin/workspaces/histolung/data/embeddings/uni_embeddings.h5")