To perform subgraph-sampling training with `Hist2Cell`, the dataset must first be converted into a specific data structure.

In `./example_data/humanlung_cell2location` and `./example_data/humanlung_cell2location_2x`, we provide the processed data for the human lung cell2location dataset used in our study.
In this tutorial, we first go through the structure of the provided processed data.

Then, we show how to convert your own raw data into this data structure.

First, let's look at the structure of the processed data for slide `WSA_LngSP9258467` from donor A50:

In [1]:
import torch

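# weights_only=False: the file holds a pickled torch_geometric Data object (required on PyTorch >= 2.6)
processed_data = torch.load("../example_data/humanlung_cell2location/WSA_LngSP9258467.pt", weights_only=False)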
processed_data

Data(x=[422, 3, 224, 224], edge_index=[2, 2732], y=[422, 330], pos=[422, 2])

The processed data contains the following variables:
- `x`: the `[3, 224, 224]` (channels, height, width) image patch for each spot in the ST data. Each spot is treated as one node in the graph representing the slide; this slide has 422 spots in total;
- `edge_index`: the graph connectivity in `COO` format with shape `[2, num_edges]`; this slide has 2732 edges, self-loops included;
- `y`: the labels for each spot: 250 highly expressed gene labels plus 80 fine-grained cell abundance labels, giving 330 labels per spot;
- `pos`: the x-y pixel coordinates of each spot on the original slide, used for visualization and for computing the cell co-localization metric;
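Because `y` concatenates two kinds of labels, downstream code usually has to split it. The sketch below assumes the 250 gene columns come first and the 80 cell-abundance columns follow, which matches the column order produced by `pd.merge(gene_label, cell_label, ...)` in the preprocessing code of this tutorial. `split_labels`, `NUM_GENES` and `NUM_CELL_TYPES` are illustrative names, and random numbers stand in for the real labels.

```python
import numpy as np

NUM_GENES = 250       # highly expressed gene labels (assumed to be the first columns of y)
NUM_CELL_TYPES = 80   # fine-grained cell abundance labels (assumed to be the last columns)


def split_labels(y: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Split a [num_spots, 330] label matrix into gene and cell-abundance blocks."""
    if y.shape[1] != NUM_GENES + NUM_CELL_TYPES:
        raise ValueError(f"expected {NUM_GENES + NUM_CELL_TYPES} label columns, got {y.shape[1]}")
    return y[:, :NUM_GENES], y[:, NUM_GENES:]


# Random stand-in for processed_data['y'].numpy() (422 spots x 330 labels)
toy_y = np.random.rand(422, 330).astype(np.float32)
gene_labels, cell_labels = split_labels(toy_y)
print(gene_labels.shape, cell_labels.shape)  # (422, 250) (422, 80)
```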

In [2]:
processed_data['x'][0].shape

torch.Size([3, 224, 224])

In [3]:
processed_data['y'][0][:5]

tensor([2.3157, 2.3157, 2.3157, 4.5253, 3.3463])

In [11]:
processed_data['edge_index']

tensor([[  0,   0,   0,  ..., 421, 421, 421],
        [  0,  54,  96,  ..., 250, 321, 421]])
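Each column of `edge_index` is one directed edge `(source, target)`; the output above shows node 0 connected to itself, so self-loops are stored as ordinary edges. A minimal numpy sketch of reading a node's neighbors and the node degrees from such a COO edge list, using a hypothetical 4-node path graph instead of the real slide (for the slide, `processed_data['edge_index'].numpy()` would take its place):

```python
import numpy as np

# Toy COO edge list for the path 0-1-2-3: row 0 = source nodes, row 1 = target nodes.
# Each undirected edge is stored in both directions, plus one self-loop per node.
toy_edge_index = np.array([
    [0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
    [0, 1, 0, 1, 2, 1, 2, 3, 2, 3],
])
num_nodes = 4
src, dst = toy_edge_index


def neighbors(node: int) -> list[int]:
    """Targets of all edges leaving `node` (includes `node` itself when it has a self-loop)."""
    return dst[src == node].tolist()


degree = np.bincount(src, minlength=num_nodes)  # edges per source node, self-loops counted
print(neighbors(0))     # [0, 1]
print(neighbors(1))     # [0, 1, 2]
print(degree.tolist())  # [2, 3, 3, 2]
```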

For data loading, we use `NeighborLoader` from `torch_geometric`, which samples subgraphs around a set of center nodes. Two parameters are important:
- `hop`: the receptive field of each subgraph sampled around a center node for training/testing. In our paper, we use 2-hop subgraphs to balance computational cost and performance; in general, a larger receptive field includes more neighborhood information.
- `subgraph_bs`: the number of subgraphs (i.e., center nodes) sampled per batch during training/testing, which is the subgraph batch size. We use `subgraph_bs=16` on our RTX 3090 GPU.
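To make the role of `hop` concrete: with `num_neighbors=[-1]*hop`, all neighbors are kept at each of the `hop` sampling steps, so for a symmetric graph like ours each sampled subgraph covers every node within `hop` edges of its center node. The breadth-first sketch below computes that node set with plain numpy on a toy path graph; `k_hop_nodes` is a hypothetical helper for illustration, not part of `torch_geometric`, which performs this sampling internally.

```python
import numpy as np


def k_hop_nodes(edge_index: np.ndarray, center: int, hop: int) -> set[int]:
    """Return all nodes within `hop` edges of `center` (center included).

    Assumes `edge_index` is a [2, num_edges] COO array of a symmetric graph.
    """
    src, dst = edge_index
    visited = {center}
    frontier = {center}
    for _ in range(hop):
        # Expand the frontier by one edge: targets of edges leaving frontier nodes.
        mask = np.isin(src, list(frontier))
        new_nodes = set(dst[mask].tolist()) - visited
        visited |= new_nodes
        frontier = new_nodes
        if not frontier:
            break
    return visited


# Toy path graph 0-1-2-3-4, each edge stored in both directions.
pairs = [(0, 1), (1, 2), (2, 3), (3, 4)]
toy_path_edges = np.array([[a for a, b in pairs] + [b for a, b in pairs],
                           [b for a, b in pairs] + [a for a, b in pairs]])
print(sorted(k_hop_nodes(toy_path_edges, center=0, hop=2)))  # [0, 1, 2]
print(sorted(k_hop_nodes(toy_path_edges, center=2, hop=2)))  # [0, 1, 2, 3, 4]
```

Increasing `hop` grows the node set (and thus the memory per subgraph) quickly, which is the cost side of the trade-off described above.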

In [1]:
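from torch_geometric.loader import NeighborLoader
import torch_geometric
torch_geometric.typing.WITH_PYG_LIB = False  # disable the pyg-lib sampling backend


hop = 2           # sample 2-hop subgraphs
subgraph_bs = 16  # number of center nodes (subgraphs) per mini-batch

dataloader_loader = NeighborLoader(
    processed_data,
    num_neighbors=[-1]*hop,  # -1: keep all neighbors at each of the `hop` sampling steps
    batch_size=subgraph_bs,
    directed=False,          # include all edges between the sampled nodes (induced subgraph)
    input_nodes=None,        # None: every node is used as a center node once per epoch
    shuffle=True,
    num_workers=2,
)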


In `torch_geometric`, the subgraphs sampled in one mini-batch are merged into a single large graph for parallel training; for more details, please refer to the `torch_geometric` documentation.

In [8]:
for subgraphs in dataloader_loader:
    print(subgraphs)
    break

Data(x=[209, 3, 224, 224], edge_index=[2, 1199], y=[209, 330], pos=[209, 2], n_id=[209], e_id=[1199], input_id=[16], batch_size=16)
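In the merged batch above, `NeighborLoader` places the `batch_size` center (seed) nodes first, and `n_id` maps every local node back to its index in the full slide graph (so `subgraphs.y` should match `processed_data.y[subgraphs.n_id]`). A common pattern is therefore to compute the loss only on the first `batch.batch_size` rows. The sketch below assumes a plain mean-squared-error loss, which may differ from the loss actually used by Hist2Cell; `center_node_mse` is a hypothetical helper, and a `SimpleNamespace` holding numpy arrays stands in for a real mini-batch so the sketch runs without `torch_geometric`.

```python
from types import SimpleNamespace

import numpy as np


def center_node_mse(pred: np.ndarray, batch) -> float:
    """MSE over the center (seed) nodes only, i.e. the first batch.batch_size nodes."""
    n = batch.batch_size
    return float(np.mean((pred[:n] - batch.y[:n]) ** 2))


# Stand-in for a sampled mini-batch: 5 nodes in total, the first 2 are center nodes.
batch = SimpleNamespace(
    y=np.array([[1.0], [2.0], [3.0], [4.0], [5.0]]),
    n_id=np.array([10, 42, 7, 11, 43]),  # global node index of each local node
    batch_size=2,
)
pred = np.array([[1.0], [4.0], [0.0], [0.0], [0.0]])  # rows of neighbor nodes are ignored

print(center_node_mse(pred, batch))   # 2.0
print(batch.n_id[:batch.batch_size])  # [10 42]
```

The neighbor nodes still contribute context through message passing; they are only excluded from the loss so that each center node is supervised once per epoch.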


Next, we show how to preprocess raw data into this structure for `Hist2Cell` training:

We provide the raw data of slide `WSA_LngSP9258467` in `./example_data/example_raw_data/WSA_LngSP9258467` for this preprocessing tutorial.

The files in this folder are:
- `patch`: folder containing the image patches for each spot in the ST data. To crop the image patches from a WSI and obtain their coordinates, please refer to the pipeline in the [DSMIL repository](https://github.com/binli123/dsmil-wsi);
- `cell_ratio.csv`: the 80 fine-grained cell type abundances for each spot;
- `stdata.csv`: the spatial gene expression of each spot;
- `log1p_stdata.csv`: the log1p-transformed spatial gene expression of each spot;
- `spots.csv`: the pixel coordinates of each spot;
- `high_250_stdata.csv`: the expression of the top 250 highly expressed genes for each spot;
- `high_250_stdata_log1p.csv`: the log1p-transformed expression of the top 250 highly expressed genes for each spot;
- `WSA_LngSP9258467.jpg`: the original slide image;
- `WSA_LngSP9258467_low_res.jpg`: a low-resolution version of the slide image, for quick visualization and other processing;
- `spot_view.jpg`: the original slide image with the spots overlaid;
- `2x_patch`: folder containing the image patches for each spot in the 2x super-resolution experiments;
- `2x_spots.csv`: the pixel coordinates of each image patch in the 2x super-resolution experiments;
- `2x_spot_view.jpg`: the original slide image with the 2x-resolution spots overlaid;
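The `*log1p*` files hold log(1 + x) of the corresponding expression values. The sketch below shows how a top-n, log1p-transformed table could be derived from a spots × genes expression table with pandas. ASSUMPTION: it ranks genes by mean expression across spots; the criterion actually used to build `high_250_stdata.csv` is not stated in this tutorial. `top_expressed_log1p` is a hypothetical helper, and the spot IDs and gene names are made up.

```python
import numpy as np
import pandas as pd


def top_expressed_log1p(stdata: pd.DataFrame, n_top: int) -> pd.DataFrame:
    """Keep the n_top genes with the highest mean expression, then apply log1p.

    ASSUMPTION: "highly expressed" = highest mean value across spots.
    stdata: rows = spot IDs, columns = genes (raw expression values).
    """
    top_genes = stdata.mean(axis=0).nlargest(n_top).index
    return np.log1p(stdata[top_genes])


# Toy expression table: 3 spots x 4 genes.
toy_stdata = pd.DataFrame(
    [[0, 5, 1, 9], [2, 7, 0, 3], [1, 6, 0, 0]],
    index=["spot_a", "spot_b", "spot_c"],
    columns=["GENE1", "GENE2", "GENE3", "GENE4"],
)
top = top_expressed_log1p(toy_stdata, n_top=2)
print(list(top.columns))  # ['GENE2', 'GENE4']
```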

To process the raw data, we first define an `STDataset` class that iterates over the spots:

In [14]:
import os
import numpy as np
import pandas as pd
from torchvision import transforms
from PIL import Image
import torch


class STDataset(torch.utils.data.Dataset):
    """Iterates over the spots of one slide, returning (spot ID, image patch, labels)."""

    def __init__(self, root, slide, transform=None):
        super().__init__()
        self.root = root
        self.slide = slide
        self.transform = transform

        # Spot IDs are taken from the patch file names, e.g. 'TCCACATCGTATATTG-1.png'.
        patch_path = os.path.join(root, slide, 'patches')
        patch_list = [x.split('.')[0] for x in os.listdir(patch_path)]

        # Per-spot labels: gene expression columns followed by cell abundance columns,
        # joined on the spot ID index.
        cell_label = pd.read_csv(os.path.join(root, slide, 'visium_ductal_breast_ground_truth-cell2locationPopulations_aligned.csv'), index_col=0)
        gene_label = pd.read_csv(os.path.join(root, slide, 'raw_counts_df.csv'), index_col=0)
        label_df = pd.merge(gene_label, cell_label, left_index=True, right_index=True)

        # Keep only spots that have both labels and an image patch.
        # Note: set iteration order is arbitrary, so the spot order may differ between runs.
        patch_list = list(set(label_df.index) & set(patch_list))
        self.label_df = label_df.loc[patch_list]
        self.patch = patch_list

    def __getitem__(self, index):
        patch_id = self.patch[index]
        patch_path = os.path.join(self.root, self.slide, 'patches', patch_id)
        patch = Image.open(patch_path + '.png').convert('RGB')
        # Downsample the patch to 20x20 (the outputs below reflect this;
        # the provided processed data uses 224x224 patches).
        data = transforms.Resize((20, 20))(patch)
        if self.transform is not None:
            data = self.transform(data)
        label = torch.Tensor(self.label_df.loc[patch_id].values)

        return patch_id, data, label

    def __len__(self):
        return len(self.patch)

In [15]:
test_transform_pcam = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.75420076, 0.6217488,  0.6946153], std=[0.11452981, 0.15589917, 0.14246847])
    ])

test_data = STDataset(root="../patch", slide="output",transform=test_transform_pcam)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=512, shuffle=False, num_workers=0)
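`transforms.ToTensor()` converts an H×W×3 `uint8` image into a 3×H×W float tensor scaled to [0, 1], and `transforms.Normalize(mean, std)` then computes `(x - mean) / std` separately for each channel. The numpy sketch below reproduces that arithmetic for one patch, using the channel statistics above, to make their effect visible; it is an illustration only, not a replacement for the torchvision transforms.

```python
import numpy as np

MEAN = np.array([0.75420076, 0.6217488, 0.6946153], dtype=np.float32)
STD = np.array([0.11452981, 0.15589917, 0.14246847], dtype=np.float32)


def to_tensor_and_normalize(img_hwc_uint8: np.ndarray) -> np.ndarray:
    """Mimic ToTensor() + Normalize(MEAN, STD) on an H x W x 3 uint8 image."""
    x = img_hwc_uint8.astype(np.float32) / 255.0            # scale to [0, 1]
    x = np.transpose(x, (2, 0, 1))                          # HWC -> CHW, as ToTensor does
    return (x - MEAN[:, None, None]) / STD[:, None, None]   # per-channel standardization


# A 20 x 20 patch whose pixels all equal the (rounded) mean colour normalizes to ~0.
mean_pixel = np.round(MEAN * 255).astype(np.uint8)
toy_patch = np.tile(mean_pixel, (20, 20, 1))
out = to_tensor_and_normalize(toy_patch)
print(out.shape)  # (3, 20, 20)
```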

We iterate over this `STDataset` and collect the spot image patches, the spot labels, and the spot IDs:

In [16]:
spot_data_array = []
spot_label_array = []
spot_id_array = []
for name, data, label in test_loader:
    spot_id_array.extend(name)                      # spot IDs, in loader order
    spot_data_array.append(data.numpy())            # [B, 3, H, W] image patches
    spot_label_array.append(label.float().numpy())  # [B, num_labels] labels

# Stack the per-batch arrays into one array per quantity.
spot_data_array = np.concatenate(spot_data_array)    # [num_spots, 3, H, W]
spot_label_array = np.concatenate(spot_label_array)  # [num_spots, num_labels]


In [18]:
spot_data_array.shape

(2518, 3, 20, 20)

In [19]:
spot_id_array[:5]

['TCCACATCGTATATTG-1',
 'TCAACAAAGATAATTC-1',
 'GAGCGCAAATACTCCG-1',
 'AGAGATCTCTAAAGCG-1',
 'GAGCGCTGTTAGGTAA-1']

To build the graph, we need the `array_col` and `array_row` of each spot, i.e., its column and row index on the Visium capture array:

In [20]:
# import scanpy as sc
# 
# 
# adata = sc.read("../example_data/example_raw_data/sp.X_norm5e4_log1p.h5ad")
# spot_array_cols = adata.obs.array_col
# spot_array_rows = adata.obs.array_row

In [21]:
# read the Visium ST data (counts plus spot metadata, including array_row/array_col)
import scanpy as sc

adata = sc.read_visium(path='../patch',
                       count_file='Visium_FFPE_Human_Breast_Cancer_filtered_feature_bc_matrix.h5',
                       library_id='A1_spot',
                       load_images=True)
adata.var_names_make_unique()
adata.var['SYMBOL'] = adata.var_names

# pandas Series indexed by spot barcode
spot_array_cols = adata.obs.array_col
spot_array_rows = adata.obs.array_row




From `./example_data/example_raw_data/WSA_LngSP9258467/spot_view.jpg`, we can see that every spot has six nearest neighbors (the spots lie on a hexagonal grid), so we build the graph according to this spatial relation:

In [22]:
spot_array_x_y = []
for item in spot_id_array:
    # (array_col, array_row) grid position of each spot, in the same order as spot_id_array
    spot_array_x_y.append([int(spot_array_cols[item]), int(spot_array_rows[item])])


In [23]:
max(spot_array_x_y)  # lexicographic max of [array_col, array_row]: the largest array_col, ties broken by array_row

[115, 41]

In [24]:
# adj[i][j] = 1 if spot j is within the neighborhood of spot i on the capture array:
# |array_col difference| < 3 and |array_row difference| < 2; self-loops are included.
adj = np.zeros((len(spot_array_x_y), len(spot_array_x_y)))

for i in range(len(spot_array_x_y)):
    for j in range(len(spot_array_x_y)):
        if i == j:
            adj[i][j] = 1.0  # self-loop
        else:
            x1, y1 = spot_array_x_y[i]
            x2, y2 = spot_array_x_y[j]
            if abs(x2 - x1) < 3 and abs(y2 - y1) < 2:
                adj[i][j] = 1.0
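The neighborhood rule in the loop above keeps spot pairs with |Δ`array_col`| < 3 and |Δ`array_row`| < 2. On a Visium-style hexagonal layout, where spots in the same row are 2 columns apart and adjacent rows are shifted by 1 column, this should capture exactly the 6 surrounding spots plus the spot itself. The sketch below checks that on a small synthetic grid and also shows a vectorized numpy equivalent of the double loop, which avoids roughly 6.3 million Python-level iterations for 2518 spots. `build_adjacency` and `hex_grid` are hypothetical helpers, and the grid is synthetic rather than read from the slide.

```python
import numpy as np


def build_adjacency(coords: np.ndarray) -> np.ndarray:
    """Vectorized equivalent of the double loop; coords is [num_spots, 2] = (array_col, array_row)."""
    dx = np.abs(coords[:, None, 0] - coords[None, :, 0])  # pairwise |delta array_col|
    dy = np.abs(coords[:, None, 1] - coords[None, :, 1])  # pairwise |delta array_row|
    # Same rule as the loop: connect if |dx| < 3 and |dy| < 2 (the diagonal gives self-loops).
    return ((dx < 3) & (dy < 2)).astype(np.float64)


def hex_grid(n_rows: int, n_cols: int) -> np.ndarray:
    """Synthetic Visium-style layout: row r only uses columns with the same parity as r."""
    return np.array([(c, r) for r in range(n_rows) for c in range(r % 2, n_cols, 2)])


toy_coords = hex_grid(n_rows=5, n_cols=10)
toy_adj = build_adjacency(toy_coords)

# Interior spot at array_col=4, array_row=2: count its neighbors, excluding the self-loop.
center = int(np.where((toy_coords[:, 0] == 4) & (toy_coords[:, 1] == 2))[0][0])
print(int(toy_adj[center].sum()) - 1)  # 6
```

For the real slide, `build_adjacency(np.array(spot_array_x_y))` should reproduce `adj`; the dense pairwise intermediates take about 50 MB each at 2518 spots and grow quadratically with the number of spots.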

In [25]:
# indices of the spots connected to spot 0 (including spot 0 itself)
neighbors_of_spot0 = np.nonzero(adj[0])[0].tolist()
print(neighbors_of_spot0)

[0, 1302, 1782, 2081, 2203]


For easy visualization, we also save the pixel coordinates of each spot:

In [26]:
spots_coord = pd.read_csv("../patch/output/tissue_positions_list_ordered.csv", index_col=0)
# reorder rows to match spot_id_array, then rescale the pixel coordinates by a dataset-specific constant
spots_coord = spots_coord.loc[spot_id_array].values * 0.07285444
spots_coord[:5]

array([[1534.02308864, 1082.10699732],
       [1239.47258772,  680.8247418 ],
       [1203.48249436, 1505.09987596],
       [1000.36431564,  564.76761888],
       [1386.49284764,  617.22281568]])

In [27]:
spots_coord.shape

(2518, 2)

Finally, we save the processed data as a `torch_geometric.data.Data` object for `Hist2Cell` training:

In [28]:
from torch_geometric.utils import dense_to_sparse
from torch_geometric.data import Data
from torch import Tensor


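x = Tensor(spot_data_array)           # [num_spots, 3, H, W] image patches
y = Tensor(spot_label_array)          # [num_spots, num_labels] spot labels
adj = Tensor(adj)
edge_index, _ = dense_to_sparse(adj)  # nonzero entries of adj -> [2, num_edges] COO edges
pos = Tensor(spots_coord)             # [num_spots, 2] pixel coordinates
data = Data(x=x, edge_index=edge_index, y=y, pos=pos)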
data

Data(x=[2518, 3, 20, 20], edge_index=[2, 16608], y=[2518, 289], pos=[2518, 2])

In [29]:
edge_index

tensor([[   0,    0,    0,  ..., 2517, 2517, 2517],
        [   0, 1302, 1782,  ..., 2307, 2336, 2517]])
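`dense_to_sparse` turns the nonzero entries of the dense adjacency matrix into COO pairs, so the edge count printed for `data` (16608) should equal the number of nonzero entries of `adj`, self-loops included. A numpy equivalent on a toy 3-spot matrix is sketched below; `np.nonzero` returns indices in row-major order, so the edges come out sorted by source node as in the output above. `dense_to_coo` is an illustrative helper, not a `torch_geometric` function.

```python
import numpy as np


def dense_to_coo(adj: np.ndarray) -> np.ndarray:
    """Return a [2, num_edges] array of (row, col) indices of the nonzero entries of adj."""
    rows, cols = np.nonzero(adj)  # row-major order: sorted by source node
    return np.stack([rows, cols])


toy_dense_adj = np.array([
    [1.0, 1.0, 0.0],
    [1.0, 1.0, 1.0],
    [0.0, 1.0, 1.0],
])
toy_coo = dense_to_coo(toy_dense_adj)
print(toy_coo.tolist())  # [[0, 0, 1, 1, 1, 2, 2], [0, 1, 0, 1, 2, 1, 2]]
print(toy_coo.shape[1] == np.count_nonzero(toy_dense_adj))  # True
```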

In [30]:
torch.save(data, "../patch/output/tissue_hires_image.pt")
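Before training, it may be worth sanity-checking the saved graph. The hypothetical helper below checks the invariants described at the start of this tutorial: `x`, `y` and `pos` have one row per node, `edge_index` has shape `[2, num_edges]` and only references existing nodes, and, for a graph built from a symmetric adjacency matrix like ours, every edge has its reverse. It works on numpy arrays; for the saved file one could pass `data.x.numpy()`, `data.edge_index.numpy()` and so on after reloading it with `torch.load`.

```python
import numpy as np


def check_graph(x: np.ndarray, edge_index: np.ndarray, y: np.ndarray, pos: np.ndarray) -> list[str]:
    """Return a list of problems found in a spot graph (empty list = all checks passed)."""
    problems = []
    num_nodes = x.shape[0]
    if y.shape[0] != num_nodes:
        problems.append(f"y has {y.shape[0]} rows, expected {num_nodes}")
    if pos.shape != (num_nodes, 2):
        problems.append(f"pos has shape {pos.shape}, expected ({num_nodes}, 2)")
    if edge_index.ndim != 2 or edge_index.shape[0] != 2:
        problems.append(f"edge_index has shape {edge_index.shape}, expected [2, num_edges]")
        return problems
    if edge_index.size and (edge_index.min() < 0 or edge_index.max() >= num_nodes):
        problems.append("edge_index references a node outside [0, num_nodes)")
    # Symmetry: the set of (src, dst) pairs should equal the set of (dst, src) pairs.
    forward = set(map(tuple, edge_index.T.tolist()))
    backward = {(d, s) for s, d in forward}
    if forward != backward:
        problems.append("edge_index is not symmetric")
    return problems


# Toy graph: 3 spots with 20 x 20 patches, 4 labels each, edges 0-1 and 1-2 in both directions.
toy_x = np.zeros((3, 3, 20, 20), dtype=np.float32)
toy_labels = np.zeros((3, 4), dtype=np.float32)
toy_pos = np.zeros((3, 2), dtype=np.float32)
toy_edges = np.array([[0, 1, 1, 2], [1, 0, 2, 1]])
print(check_graph(toy_x, toy_edges, toy_labels, toy_pos))  # []
```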