In [2]:
import os
import h5py
import scanpy as sc
import numpy as np
import pandas as pd
import torch
import torchvision.transforms as transforms
from PIL import Image
from scipy.spatial import Delaunay

# Custom dataset class for DLPFC
class DLPFC_Dataset(torch.utils.data.Dataset):
    def __init__(self, data_dir, batch_id, transform=None):
        """
        Args:
            data_dir (str): Path to the dataset directory.
            batch_id (str): Batch folder name (e.g., "batch1").
            transform: Optional transformations for images.
        """
        self.data_dir = os.path.join(data_dir, batch_id)
        self.transform = transform

        # Load gene expression data
        self.gene_exp = self.load_gene_expression()

        # Load metadata
        self.metadata = pd.read_csv(os.path.join(self.data_dir, "metadata.tsv"), sep="\t")

        # Load spatial coordinates
        self.spatial_info = pd.read_csv(os.path.join(self.data_dir, "spatial", "tissue_positions_list.csv"), header=None)
        self.spatial_info.columns = ["spot_id", "in_tissue", "x", "y", "pixel_x", "pixel_y"]

        # Load high-resolution image
        img_path = os.path.join(self.data_dir, "spatial", "tissue_hires_image.png")
        self.image = Image.open(img_path).convert("RGB")

        # Compute adjacency matrix using Delaunay triangulation
        self.adj_matrix = self.compute_adjacency_matrix()

    def load_gene_expression(self):
        """Load the filtered_feature_bc_matrix.h5 file."""
        h5_path = os.path.join(self.data_dir, "filtered_feature_bc_matrix.h5")
        with h5py.File(h5_path, "r") as f:
            genes = [x.decode("utf-8") for x in f["matrix"]["features"]["name"]]
            barcodes = [x.decode("utf-8") for x in f["matrix"]["barcodes"]]
            matrix = f["matrix"]["data"][:]
            indices = f["matrix"]["indices"][:]
            indptr = f["matrix"]["indptr"][:]
            shape = f["matrix"]["shape"][:]

        # Convert sparse matrix to dense format
        from scipy.sparse import csr_matrix
        gene_exp_matrix = csr_matrix((matrix, indices, indptr), shape=shape).toarray()
        return pd.DataFrame(gene_exp_matrix.T, index=barcodes, columns=genes)

    def compute_adjacency_matrix(self):
        """Compute adjacency matrix using Delaunay triangulation."""
        points = self.spatial_info[["pixel_x", "pixel_y"]].values
        tri = Delaunay(points)
        adjacency_matrix = np.zeros((len(points), len(points)))

        for simplex in tri.simplices:
            for i in range(3):
                for j in range(i + 1, 3):
                    adjacency_matrix[simplex[i], simplex[j]] = 1
                    adjacency_matrix[simplex[j], simplex[i]] = 1
        return adjacency_matrix

    def __len__(self):
        return len(self.spatial_info)

    def __getitem__(self, idx):
        spot_info = self.spatial_info.iloc[idx]
        spot_id = spot_info["spot_id"]
        x, y = int(spot_info["pixel_x"]), int(spot_info["pixel_y"])

        # Extract gene expression for this spot
        if spot_id in self.gene_exp.index:
            gene_values = torch.tensor(self.gene_exp.loc[spot_id].values, dtype=torch.float32)
        else:
            gene_values = torch.zeros(len(self.gene_exp.columns), dtype=torch.float32)

        # Extract image patch centered at (x, y)
        patch_size = 224  # Change as needed
        left, upper = max(0, x - patch_size // 2), max(0, y - patch_size // 2)
        right, lower = left + patch_size, upper + patch_size
        image_patch = self.image.crop((left, upper, right, lower))

        if self.transform:
            image_patch = self.transform(image_patch)

        return image_patch, gene_values, torch.tensor([x, y], dtype=torch.float32)

# Example Usage
if __name__ == "__main__":
    dataset = DLPFC_Dataset(data_dir="/home/lytq/Hist2ST/data/st_data/DLPFC_new", batch_id="151673", transform=transforms.ToTensor())
    image_patch, gene_values, spatial_coords = dataset[0]

    print("Image Patch Shape:", image_patch.shape)
    print("Gene Expression Shape:", gene_values.shape)
    print("Spatial Coordinates:", spatial_coords)


ValueError: index pointer size (3640) should be (33539)