In [None]:
import os
from tqdm import tqdm
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

import pandas as pd
import numpy as np
from PIL import Image
Image.MAX_IMAGE_PIXELS = None
import scanpy as sc
import anndata
from scipy.spatial import KDTree
import torch

In [None]:

from huggingface_hub import login
# from CONCH.conch.open_clip_custom import create_model_from_pretrained

login("")  # login with your User Access Token, found at https://huggingface.co/settings/tokens

model_UNI = timm.create_model("hf-hub:MahmoodLab/uni", pretrained=True, init_values=1e-5, dynamic_img_size=True)
transform_UNI = create_transform(**resolve_data_config(model_UNI.pretrained_cfg, model=model_UNI))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_UNI.eval()
model_UNI.to(device)


The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: fineGrained).
Your token has been saved to C:\Users\Administrator\.cache\huggingface\token
Login successful


VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
    (norm): Identity()
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (patch_drop): Identity()
  (norm_pre): Identity()
  (blocks): Sequential(
    (0): Block(
      (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=1024, out_features=3072, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=1024, out_features=1024, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): LayerScale()
      (drop_path1): Identity()
      (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=1024, out_features=4096, bias=True)
        (act): GELU(approximate='none')
        (drop1): Dropout(p=0.0, inplace=False)
        (norm): 

In [3]:
def get_img_ebd_uni(patch,model = model_UNI,transform = transform_UNI):
    base_width = 224
    patch_resized = patch.resize((base_width, base_width), 
                                     Image.Resampling.LANCZOS)       # [224, 224]
    img_transformed = transform(patch_resized).unsqueeze(0).to(device) 
    with torch.inference_mode():
        feature_emb = model(img_transformed)
    return torch.clone(feature_emb)


In [4]:
data_path = "./hest1k_datasets/Her2st/"
tif_path = data_path + 'wsis/'
st_path = data_path + 'st/'


In [None]:
st_data_list = []
file_name_list = ["SPA" + str(i) for i in range(119,155)]

first = True
for fn in file_name_list:
    st_data = anndata.read_h5ad(st_path + fn + '.h5ad')
    st_data_list.append(st_data)
    if first:
        commen_genes = st_data.var_names
        first = False
        print(fn,st_data.shape)
        continue
    commen_genes = set(commen_genes).intersection(set(st_data.var_names))
    print(fn,st_data.shape,end='\t')

print("length of commen genes: ", len(commen_genes))
commen_genes = sorted(list(commen_genes))
for fni in range(len(file_name_list)):
    st_data = st_data_list[fni].copy()
    st_data_list[fni] = st_data[:,commen_genes].copy()
    print(file_name_list[fni],"  ",st_data_list[fni].shape)
print('Only keep commen genes')



MISC1 (3460, 33538)
MISC2 (3592, 33538)	MISC3 (3673, 33538)	MISC4 (3639, 33538)	MISC5 (4015, 33538)	MISC6 (4110, 33538)	MISC7 (3498, 33538)	MISC8 (3661, 33538)	MISC9 (4634, 33538)	MISC10 (4789, 33538)	MISC11 (4384, 33538)	MISC12 (4226, 33538)	length of commen genes:  33538
MISC1    (3460, 33538)
MISC2    (3592, 33538)
MISC3    (3673, 33538)
MISC4    (3639, 33538)
MISC5    (4015, 33538)
MISC6    (4110, 33538)
MISC7    (3498, 33538)
MISC8    (3661, 33538)
MISC9    (4634, 33538)
MISC10    (4789, 33538)
MISC11    (4384, 33538)
MISC12    (4226, 33538)
Only keep commen genes


In [None]:
os.makedirs(data_path + 'processed_data/', exist_ok=True)
os.makedirs(data_path + 'processed_data/local_ebd/', exist_ok=True)
os.makedirs(data_path + 'processed_data/neighbor_ebd/', exist_ok=True)
os.makedirs(data_path + 'processed_data/global_ebd/', exist_ok=True)
save_path = data_path + 'processed_data/'


In [8]:

for i in range(len(file_name_list)):
    fn = file_name_list[i]
    st_data = st_data_list[i]
    print(fn)
    img = Image.open(tif_path + fn + ".tif")
    x = st_data.obsm["spatial"][:, 0]          # x coordinate in H&E image
    y = st_data.obsm["spatial"][:, 1]          # y coordinate in H&E image
    coords = np.stack([x,y],axis=1)
    tree = KDTree(coords)
    spot_diameter = st_data.uns["spatial"]["ST"]["scalefactors"]["spot_diameter_fullres"]
    print(f" {fn}  |  Spot diameter: ", spot_diameter)  # Spot diameter for Visium
    if spot_diameter < 224:
        radius = 112  # minimum patch size: 224 by 224
    else:
        radius = int(spot_diameter // 2)
    all_local_ebd = None
    global_ebd = []
    neighbor_ebd = []
    all_neighbor_ebd = []
    all_global_ebd = []
    first = True
    for spot_idx in tqdm(range(len(x))):
        patch = img.crop((x[spot_idx] - radius, y[spot_idx] - radius,
                          x[spot_idx] + radius, y[spot_idx] + radius))
        patch_ebd_uni = get_img_ebd_uni(patch)
        # patch_ebd_conch = get_img_embd_conch(patch)
        # patch_ebd  = torch.cat([patch_ebd_uni, patch_ebd_conch], dim=1)
        if first:
            all_local_ebd = patch_ebd_uni
            first = False
            continue
        all_local_ebd = torch.cat((all_local_ebd,patch_ebd_uni),dim=0)
    for j in range(len(coords)):
        _, n_idx = tree.query(coords[j], k=9)
        _, g_idx = tree.query(coords[j], k=49)
        neighbor_ebd.append(all_local_ebd[n_idx].unsqueeze(0))  # [1, 9, D]
        global_ebd.append(all_local_ebd[g_idx].unsqueeze(0))    # [1, 49, D]
        
    all_neighbor_ebd = torch.cat(neighbor_ebd, dim=0)  # [N_i, 9, D]
    all_global_ebd = torch.cat(global_ebd, dim=0)      # [N_i, 49, D]
    print(f' local ebd shape: {all_local_ebd.shape} | neighbor ebd shape: {all_neighbor_ebd.shape} | global ebd shape: {all_global_ebd.shape}')
    if save_path != None:
        torch.save(all_neighbor_ebd.detach().cpu(),save_path + 'neighbor_ebd/' + fn +".pt")
        torch.save(all_local_ebd.detach().cpu(),save_path + 'local_ebd/' + fn +".pt")
        torch.save(all_global_ebd.detach().cpu(),save_path + 'global_ebd/' + fn +".pt")
    print("#" * 20)

MISC1
 MISC1  |  Spot diameter:  75.74621941347247


  x = F.scaled_dot_product_attention(
100%|██████████| 3460/3460 [01:54<00:00, 30.09it/s]


 local ebd shape: torch.Size([3460, 1536]) | neighbor ebd shape: torch.Size([3460, 9, 1536]) | global ebd shape: torch.Size([3460, 49, 1536])
####################
MISC2
 MISC2  |  Spot diameter:  75.74621941347247


100%|██████████| 3592/3592 [01:59<00:00, 29.94it/s]


 local ebd shape: torch.Size([3592, 1536]) | neighbor ebd shape: torch.Size([3592, 9, 1536]) | global ebd shape: torch.Size([3592, 49, 1536])
####################
MISC3
 MISC3  |  Spot diameter:  75.84913072618464


100%|██████████| 3673/3673 [02:08<00:00, 28.57it/s]


 local ebd shape: torch.Size([3673, 1536]) | neighbor ebd shape: torch.Size([3673, 9, 1536]) | global ebd shape: torch.Size([3673, 49, 1536])
####################
MISC4
 MISC4  |  Spot diameter:  75.74621941347245


100%|██████████| 3639/3639 [01:58<00:00, 30.69it/s]


 local ebd shape: torch.Size([3639, 1536]) | neighbor ebd shape: torch.Size([3639, 9, 1536]) | global ebd shape: torch.Size([3639, 49, 1536])
####################
MISC5
 MISC5  |  Spot diameter:  75.64162790401343


100%|██████████| 4015/4015 [02:10<00:00, 30.76it/s]


 local ebd shape: torch.Size([4015, 1536]) | neighbor ebd shape: torch.Size([4015, 9, 1536]) | global ebd shape: torch.Size([4015, 49, 1536])
####################
MISC6
 MISC6  |  Spot diameter:  75.74471614228091


100%|██████████| 4110/4110 [02:13<00:00, 30.84it/s]


 local ebd shape: torch.Size([4110, 1536]) | neighbor ebd shape: torch.Size([4110, 9, 1536]) | global ebd shape: torch.Size([4110, 49, 1536])
####################
MISC7
 MISC7  |  Spot diameter:  75.84766876939577


100%|██████████| 3498/3498 [01:53<00:00, 30.70it/s]


 local ebd shape: torch.Size([3498, 1536]) | neighbor ebd shape: torch.Size([3498, 9, 1536]) | global ebd shape: torch.Size([3498, 49, 1536])
####################
MISC8
 MISC8  |  Spot diameter:  75.7447512351817


100%|██████████| 3661/3661 [01:58<00:00, 30.89it/s]


 local ebd shape: torch.Size([3661, 1536]) | neighbor ebd shape: torch.Size([3661, 9, 1536]) | global ebd shape: torch.Size([3661, 49, 1536])
####################
MISC9
 MISC9  |  Spot diameter:  75.74510080402263


100%|██████████| 4634/4634 [02:30<00:00, 30.73it/s]


 local ebd shape: torch.Size([4634, 1536]) | neighbor ebd shape: torch.Size([4634, 9, 1536]) | global ebd shape: torch.Size([4634, 49, 1536])
####################
MISC10
 MISC10  |  Spot diameter:  75.7447512351817


100%|██████████| 4789/4789 [02:35<00:00, 30.72it/s]


 local ebd shape: torch.Size([4789, 1536]) | neighbor ebd shape: torch.Size([4789, 9, 1536]) | global ebd shape: torch.Size([4789, 49, 1536])
####################
MISC11
 MISC11  |  Spot diameter:  75.95066061878809


100%|██████████| 4384/4384 [02:22<00:00, 30.68it/s]


 local ebd shape: torch.Size([4384, 1536]) | neighbor ebd shape: torch.Size([4384, 9, 1536]) | global ebd shape: torch.Size([4384, 49, 1536])
####################
MISC12
 MISC12  |  Spot diameter:  75.84797789755349


100%|██████████| 4226/4226 [02:16<00:00, 30.87it/s]


 local ebd shape: torch.Size([4226, 1536]) | neighbor ebd shape: torch.Size([4226, 9, 1536]) | global ebd shape: torch.Size([4226, 49, 1536])
####################


In [9]:
union_hvg = set()

for fn_idx in range(len(file_name_list)):
    st_data = st_data_list[fn_idx].copy()
    fn = file_name_list[fn_idx]
    
    sc.pp.filter_cells(st_data, min_genes=1)
    sc.pp.filter_genes(st_data, min_cells=1)
    sc.pp.normalize_total(st_data, inplace=True)
    sc.pp.log1p(st_data)
    sc.pp.highly_variable_genes(st_data, n_top_genes=2000)

    union_hvg = union_hvg.union(set(st_data.var_names[st_data.var["highly_variable"]]))
    print(fn, len(union_hvg))

union_hvg = sorted([gene for gene in union_hvg])
print(len(union_hvg))

MISC1 2000
MISC2 3536
MISC3 4825
MISC4 5936
MISC5 6965
MISC6 7788
MISC7 8726
MISC8 9574
MISC9 10287
MISC10 10883
MISC11 11393
MISC12 11957
11957


In [10]:
# select union_hvg and concat all slides
all_count_df = pd.DataFrame(st_data_list[0][:, union_hvg].X.toarray(), 
                            columns=union_hvg, 
                            index=[file_name_list[0] + "_" + str(i) for i in range(st_data_list[0].shape[0])]).T

for fn_idx in range(0, len(file_name_list)):
    st_data = st_data_list[fn_idx]
    df = pd.DataFrame(st_data[:, union_hvg].X.toarray(), 
                      columns=union_hvg, 
                      index=[file_name_list[fn_idx] + "_" + str(i) for i in range(st_data.shape[0])]).T
    all_count_df = pd.concat([all_count_df, df], axis=1)
    print(file_name_list[fn_idx], st_data.shape, all_count_df.shape)

all_count_df.fillna(0, inplace=True)
all_count_df = all_count_df.T

MISC1 (3460, 33538) (11957, 6920)
MISC2 (3592, 33538) (11957, 10512)
MISC3 (3673, 33538) (11957, 14185)
MISC4 (3639, 33538) (11957, 17824)
MISC5 (4015, 33538) (11957, 21839)
MISC6 (4110, 33538) (11957, 25949)
MISC7 (3498, 33538) (11957, 29447)
MISC8 (3661, 33538) (11957, 33108)
MISC9 (4634, 33538) (11957, 37742)
MISC10 (4789, 33538) (11957, 42531)
MISC11 (4384, 33538) (11957, 46915)
MISC12 (4226, 33538) (11957, 51141)


In [11]:
# order selected genes by mean and std
all_gene_order_by_mean = all_count_df.mean(axis=0).sort_values(ascending=False).index
all_gene_order_by_std = all_count_df.std(axis=0).sort_values(ascending=False).index

In [18]:
# select top intersection of high mean and high variance genes

num_genes = 300 # to make final gene list of length 200

selected_genes = sorted(list(set(all_gene_order_by_mean[:num_genes]).intersection(set(all_gene_order_by_std[:num_genes]))))
print(len(selected_genes))

1000


In [None]:
with open(data_path + "processed_data/selected_gene_list.txt", "w") as f:
    for gene in selected_genes:
        f.write(gene + "\n")

with open(data_path + "processed_data/all_slide_lst.txt", "w") as f:
    for fn in file_name_list:
        f.write(fn + "\n")