In [1]:
import os
import torch
from torchvision import transforms
import timm
from huggingface_hub import login, hf_hub_download
import scanpy as sc
import pandas as pd
import sys
sys.path.append("./MuCST")

https://www.10xgenomics.com/datasets/human-breast-cancer-block-a-section-1-1-standard-1-1-0

In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [15]:
local_dir = "/c/Users/Hao/Dropbox/hw/RuiPath/uni/UNI/assets/ckpts/uni2-h/"
# os.makedirs(local_dir, exist_ok=True)  # create directory if it does not exist
# hf_hub_download("MahmoodLab/UNI2-h", filename="pytorch_model.bin", local_dir=local_dir)
timm_kwargs = {
            'model_name': 'vit_giant_patch14_224',
            'img_size': 224, 
            'patch_size': 14, 
            'depth': 24,
            'num_heads': 24,
            'init_values': 1e-5, 
            'embed_dim': 1536,
            'mlp_ratio': 2.66667*2,
            'num_classes': 0, 
            'no_embed_class': True,
            'mlp_layer': timm.layers.SwiGLUPacked, 
            'act_layer': torch.nn.SiLU, 
            'reg_tokens': 8, 
            'dynamic_img_size': True
        }
model = timm.create_model(
    pretrained=False, **timm_kwargs
)
model.load_state_dict(torch.load("./uni/UNI/assets/ckpts/uni2-h/pytorch_model.bin", map_location=device), strict=True)
transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)
model.eval();

In [16]:
model.to(device);

In [None]:
import numpy

In [None]:
dir_input = "SEDR_analyses/data/BRCA1/V1_Human_Breast_Cancer_Block_A_Section_1"

In [36]:
metadata = pd.read_csv("SEDR_analyses/data/BRCA1/metadata.tsv", sep='\t')

In [37]:
metadata

Unnamed: 0,ID,annot_type,fine_annot_type
0,AAACAAGTATCTCCCA-1,Surrounding tumor,Tumor_edge_5
1,AAACACCAATAACTGC-1,Invasive,IDC_4
2,AAACAGAGCGACTCCT-1,Healthy,Healthy_1
3,AAACAGGGTCTATATT-1,Invasive,IDC_3
4,AAACAGTGTTCCTGGG-1,Invasive,IDC_4
...,...,...,...
3793,TTGTTGTGTGTCAAGA-1,Invasive,IDC_7
3794,TTGTTTCACATCCAGG-1,Invasive,IDC_4
3795,TTGTTTCATTAGTCTA-1,Invasive,IDC_4
3796,TTGTTTCCATACAACT-1,Surrounding tumor,Tumor_edge_2


In [None]:
adata = sc.read_10x_h5(f'{dir_input}/filtered_feature_bc_matrix.h5')
adata.var_names_make_unique()

spatial=pd.read_csv(f"{dir_input}/spatial/tissue_positions_list.csv",sep=",",header=None,na_filter=False,index_col=0)

adata.obs["x1"]=spatial[1]
adata.obs["x2"]=spatial[2]
adata.obs["x3"]=spatial[3]
adata.obs["x4"]=spatial[4]
adata.obs["x5"]=spatial[5]

adata=adata[adata.obs["x1"]==1]
adata.var_names=[i.upper() for i in list(adata.var_names)]
adata.var["genename"]=adata.var.index.astype("str")
adata.obs['pred'] = metadata['annot_type'].values

#Read in hitology image
# img=cv2.imread(f"{dir_input}/spatial/full_image.tif")

#Set coordinates
adata.obs["x_array"]=adata.obs["x2"]
adata.obs["y_array"]=adata.obs["x3"]
adata.obs["x_pixel"]=adata.obs["x4"]
adata.obs["y_pixel"]=adata.obs["x5"]
x_array=adata.obs["x_array"].tolist()
y_array=adata.obs["y_array"].tolist()
x_pixel=adata.obs["x_pixel"].tolist()
y_pixel=adata.obs["y_pixel"].tolist()

In [None]:
#Set colors used
plot_color=["#F56867","#FEB915","#C798EE","#59BE86","#7495D3","#D1D1D1","#6D1A9C","#15821E","#3A84E6","#997273","#787878","#DB4C6C","#9E7A7A","#554236","#AF5F3C","#93796C","#F9BD3F","#DAB370","#877F6C","#268785"]
#Plot spatial domains
domains="pred"
num_celltype=len(adata.obs[domains].unique())
adata.uns[domains+"_colors"]=list(plot_color[:num_celltype])
ax=sc.pl.scatter(adata,alpha=1,x="y_pixel",y="x_pixel",color=domains,title=domains,color_map=plot_color,show=False,size=100000/adata.shape[0])
ax.set_aspect('equal', 'box')
ax.axes.invert_yaxis()

In [None]:
adata.obs['fine'] = metadata['fine_annot_type'].values
#Set colors used
plot_color=["#F56867","#FEB915","#C798EE","#59BE86","#7495D3","#D1D1D1","#6D1A9C","#15821E","#3A84E6","#997273","#787878","#DB4C6C","#9E7A7A","#554236","#AF5F3C","#93796C","#F9BD3F","#DAB370","#877F6C","#268785"]
#Plot spatial domains
domains="fine"
num_celltype=len(adata.obs[domains].unique())
adata.uns[domains+"_colors"]=list(plot_color[:num_celltype])
ax=sc.pl.scatter(adata,alpha=1,x="y_pixel",y="x_pixel",color=domains,title=domains,color_map=plot_color,show=False,size=100000/adata.shape[0])
ax.set_aspect('equal', 'box')
ax.axes.invert_yaxis()

In [None]:
import sys 
sys.path.append("../model/") # path to this folder
from load import *
ckpt_path = "./scFoundation/model/models/models.ckpt"
pretrainmodel,pretrainconfig = load_model_frommmf(ckpt_path)

In [None]:
adata.obs['y_pixel'].max() - adata.obs['y_pixel'].min()

In [6]:
from PIL import Image
from tqdm import tqdm
from pathlib import Path
import numpy as np

In [None]:
Image.MAX_IMAGE_PIXELS = None
img = Image.open('./SEDR_analyses/data/BRCA1/V1_Human_Breast_Cancer_Block_A_Section_1/V1_Breast_Cancer_Block_A_Section_1_image.tif')

In [None]:
def image_crop(adata, img, crop_size=224, 
               target_size=224, save_path=None, 
               verbose=False):
    tile_names = []

    with tqdm(total=len(adata), desc='Tiling Image', bar_format='{l_bar}{bar} [ time left: {remaining} ]') as pbar:
        for image_row, image_col in zip(adata.obs['x_pixel'], adata.obs['y_pixel']):
            image_down = image_row - crop_size / 2
            image_up = image_row + crop_size / 2
            image_left = image_col - crop_size / 2
            image_right = image_col + crop_size / 2

            tile = img.crop(
                (image_left, image_down, image_right, image_up)
            )
            tile.thumbnail((target_size, target_size), Image.LANCZOS)
            tile.resize((target_size, target_size))
            tile_name = str(image_col) + '-' + str(image_row) + '-' + str(crop_size)
            if save_path is not None:
                out_tile = Path(save_path) / (tile_name + '.png')
                tile_names.append(str(out_tile))
                if verbose:
                    print('Generating tile at location ({}, {})'.format(str(image_col), str(image_row)))
                tile.save(out_tile, 'PNG')
            pbar.update(1)

    adata.obs['slice_path'] = tile_names
    return adata

In [None]:
# image_crop(adata, img, crop_size=112, target_size=224, save_path="./10x_BRC", verbose=True)

In [18]:
image = Image.open("./10x_BRC/4176-20901-112.png")
image = transform(image).unsqueeze(dim=0) # Image (torch.Tensor) with shape [1, 3, 224, 224] following image resizing and normalization (ImageNet parameters)
with torch.inference_mode():
 feature_emb = model(image.to(device))

In [29]:
img_encodings = {}

# Process all PNG files
directory_path = './10x_BRC/'
for filename in os.listdir(directory_path):
    file_path = os.path.join(directory_path, filename)
        
    try:
        # Load and preprocess image
        image = Image.open(file_path)
        input_tensor = transform(image).unsqueeze(0).to(device) 
        
        # Get encoding
        with torch.no_grad():
            img_encoding = model(input_tensor)
            img_encoding = img_encoding.squeeze()
        
        img_encodings[filename] = img_encoding.detach().cpu().numpy()
        print(f"Encoded {filename}: shape {img_encoding.shape}")
        
    except Exception as e:
        print(f"Error processing {filename}: {e}")
    print(str(len(img_encodings)/len(os.listdir(directory_path))))
    

Encoded 10046-21149-112.png: shape torch.Size([1536])
0.0002632964718272775
Encoded 10047-20199-112.png: shape torch.Size([1536])
0.000526592943654555
Encoded 10047-20674-112.png: shape torch.Size([1536])
0.0007898894154818325
Encoded 10048-19724-112.png: shape torch.Size([1536])
0.00105318588730911
Encoded 10049-19249-112.png: shape torch.Size([1536])
0.0013164823591363876
Encoded 10050-18774-112.png: shape torch.Size([1536])
0.001579778830963665
Encoded 10051-18299-112.png: shape torch.Size([1536])
0.0018430753027909425
Encoded 10056-15448-112.png: shape torch.Size([1536])
0.00210637177461822
Encoded 10057-14973-112.png: shape torch.Size([1536])
0.002369668246445498
Encoded 10058-14498-112.png: shape torch.Size([1536])
0.0026329647182727752
Encoded 10059-14023-112.png: shape torch.Size([1536])
0.0028962611901000527
Encoded 10060-13073-112.png: shape torch.Size([1536])
0.00315955766192733
Encoded 10060-13548-112.png: shape torch.Size([1536])
0.0034228541337546076
Encoded 10061-12598-1

In [30]:
img_encodings_df = pd.DataFrame.from_dict(img_encodings, orient='index')

In [31]:
img_encodings_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,1526,1527,1528,1529,1530,1531,1532,1533,1534,1535
10046-21149-112.png,0.845395,-0.406961,-0.064259,-0.732731,-0.048546,-0.480646,-0.059120,-0.701596,-0.177067,0.821242,...,0.074345,-0.266227,-0.253755,0.046043,0.103824,0.000789,0.398459,0.173685,-0.141150,0.054217
10047-20199-112.png,0.392421,-0.351532,-0.339198,-0.607163,0.027208,-0.372003,-0.139642,-0.527919,-0.245256,0.493132,...,-0.050212,-0.218143,0.002228,0.267319,0.170577,-0.147587,0.517183,-0.048126,-0.185808,-0.018122
10047-20674-112.png,0.282134,-0.504845,-0.013232,-0.608782,-0.049500,-0.326767,-0.080200,-0.582259,0.009425,0.352604,...,0.018401,-0.185991,0.027792,-0.038198,-0.000845,-0.161631,0.576289,-0.069880,-0.138642,-0.016649
10048-19724-112.png,0.149072,-0.168119,-0.151536,-0.723701,-0.186116,-0.556082,0.067825,-0.469686,-0.168141,0.219883,...,0.066718,-0.023171,-0.048679,-0.007728,0.369773,0.027831,0.605612,-0.059576,-0.068349,-0.034603
10049-19249-112.png,0.402921,-0.406703,-0.547809,-0.499998,-0.262068,0.091035,0.041287,-0.698881,-0.375097,0.568460,...,-0.126301,0.249389,0.247691,0.292296,-0.127715,0.252549,0.589073,-0.191899,-0.007843,0.120796
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9935-7134-112.png,0.677207,-0.667800,-0.688258,-0.626875,-0.276892,-0.285960,-0.041374,-0.836034,-0.002152,0.132720,...,-0.181014,-0.235889,0.169988,0.082823,-0.223273,0.124130,0.332062,-0.217668,0.075362,0.239580
9936-6659-112.png,0.868473,-0.643630,-0.372868,-0.723963,-0.160293,-0.774316,-0.059533,-0.714584,0.212742,0.047704,...,-0.444638,0.258513,-0.105204,-0.175431,-0.298183,0.325280,0.650988,-0.129049,0.453918,-0.318257
9937-5708-112.png,0.911758,-0.511487,-0.291454,-0.735612,-0.124204,-0.089685,-0.019259,-0.473784,0.125871,0.089400,...,-0.150862,-0.004110,0.161129,0.309722,0.136660,0.111421,0.609194,-0.230430,0.221502,-0.230389
9937-6183-112.png,0.851856,-0.399228,-0.460762,-0.519031,-0.059516,0.045539,-0.326275,-0.371766,0.182423,0.399901,...,-0.173944,-0.376348,-0.038472,-0.015250,-0.305643,-0.326195,0.112948,0.113722,0.187680,0.054941


In [32]:
from sklearn.cluster import KMeans

In [33]:
kmeans = KMeans(n_clusters=4, random_state=42)
y_pred = kmeans.fit_predict(img_encodings_df)

In [34]:
y_pred

array([1, 1, 1, ..., 1, 2, 1], shape=(3798,), dtype=int32)

In [38]:
from sklearn.metrics import adjusted_rand_score
ari = adjusted_rand_score(y_pred, metadata['annot_type'].values)
print(f"Adjusted Rand Index: {ari:.4f}")

Adjusted Rand Index: -0.0001


In [None]:
feature_emb

In [None]:

adata.obs_names

In [None]:
adata.var

In [None]:
gene_list_df = pd.read_csv('./scFoundation/OS_scRNA_gene_index.19264.tsv', header=0, delimiter='\t')
gene_list = list(gene_list_df['gene_name'])

In [None]:
gexpr_feature = adata.X.toarray()
idx = adata.obs_names.tolist()
col = adata.var['genename'].tolist()
gexpr_feature = pd.DataFrame(gexpr_feature,index=idx,columns=col)

In [None]:
def main_gene_selection(X_df, gene_list):
    """
    Describe:
        rebuild the input adata to select target genes encode protein 
    Parameters:
        adata->`~anndata.AnnData` object: adata with var index_name by gene symbol
        gene_list->list: wanted target gene 
    Returns:
        adata_new->`~anndata.AnnData` object
        to_fill_columns->list: zero padding gene
    """
    to_fill_columns = list(set(gene_list) - set(X_df.columns))
    padding_df = pd.DataFrame(np.zeros((X_df.shape[0], len(to_fill_columns))), 
                              columns=to_fill_columns, 
                              index=X_df.index)
    X_df = pd.DataFrame(np.concatenate([df.values for df in [X_df, padding_df]], axis=1), 
                        index=X_df.index, 
                        columns=list(X_df.columns) + list(padding_df.columns))
    X_df = X_df[gene_list]
    
    var = pd.DataFrame(index=X_df.columns)
    var['mask'] = [1 if i in to_fill_columns else 0 for i in list(var.index)]
    return X_df, to_fill_columns,var

In [None]:
gexpr_feature, _, _ = main_gene_selection(gexpr_feature,gene_list)

In [None]:
gexpr_feature

In [None]:
adata_tmp = sc.AnnData(gexpr_feature)
sc.pp.normalize_total(adata_tmp)
sc.pp.log1p(adata_tmp)
gexpr_feature_normalized = pd.DataFrame(adata_tmp.X,index=adata_tmp.obs_names,columns=adata_tmp.var_names)

In [None]:
gexpr_feature_normalized

In [None]:
sys.path.append("./scFoundation/model/") 
from load import *

In [None]:
tgthighres = 'h5'
geneexpemb=[]

for i in tqdm(range(gexpr_feature.shape[0])):
    with torch.no_grad():
        tmpdata = (gexpr_feature.iloc[i,:]).tolist()
        totalcount = gexpr_feature.iloc[i,:].sum()
        pretrain_gene_x = torch.tensor(tmpdata+[np.log10(totalcount)+float(tgthighres[1:]),np.log10(totalcount)]).unsqueeze(0).cuda()
        value_labels = pretrain_gene_x > 0
        data_gene_ids = torch.arange(19266, device=pretrain_gene_x.device).repeat(pretrain_gene_x.shape[0], 1)
        x, x_padding = gatherData(pretrain_gene_x, value_labels, pretrainconfig['pad_token_id'])
        position_gene_ids, _ = gatherData(data_gene_ids, value_labels, pretrainconfig['pad_token_id'])
        x = pretrainmodel.token_emb(torch.unsqueeze(x, 2).float(), output_weight = 0)
        position_emb = pretrainmodel.pos_emb(position_gene_ids)
        x += position_emb
        geneemb = pretrainmodel.encoder(x,x_padding)

        geneemb1 = geneemb[:,-1,:]
        geneemb2 = geneemb[:,-2,:]
        geneemb3, _ = torch.max(geneemb[:,:-2,:], dim=1)
        geneemb4 = torch.mean(geneemb[:,:-2,:], dim=1)
        geneembmerge = torch.concat([geneemb1,geneemb2,geneemb3,geneemb4],axis=1)
        geneexpemb.append(geneembmerge.detach().cpu().numpy())

In [None]:
gexpr_feature

In [None]:
geneembmerge.shape 