In [None]:
from shapely.geometry import shape, GeometryCollection, Point, Polygon
import os
import json
import pickle
import pandas as pd
import openslide
import cv2
from pathpretrain import load_image
import geojson
import matplotlib.pyplot as plt
import numpy as np
import geopandas as gpd
import scanpy as sc
import torch
from torch_geometric.nn.pool import voxel_grid
import plotly.graph_objects as go
import anndata as ad
import plotly.io as pio
import plotly.express as px
from scipy import stats
from sklearn.metrics import calinski_harabasz_score
from sklearn.metrics import fowlkes_mallows_score

In [None]:
def coarsen(full_xys, full_exp, size):
    """Downsample spots onto a square voxel grid, keeping one representative per voxel.

    Every non-empty voxel of edge ``size`` (as assigned by torch_geometric's
    ``voxel_grid``) is reduced to:
      * the coordinate of the member spot closest to the voxel centroid, and
      * the mean expression vector of all member spots.

    Args:
        full_xys: numpy array of spot coordinates; rows are reshaped to
            ``(-1, 2)``, so presumably shape ``(n_spots, 2)``.
        full_exp: per-spot expression matrix, one row per spot. Dense numpy
            arrays are used directly; row slices exposing ``toarray`` (e.g.
            scipy sparse matrices) are densified per voxel.
        size: voxel edge length, in the same units as ``full_xys``.

    Returns:
        ``(coarse_xys, coarse_exp, central_point_indices)`` -- parallel lists
        with one entry per voxel:
          * coarse_xys: numpy arrays of shape ``(1, 2)``;
          * coarse_exp: numpy arrays of shape ``(1, n_genes)`` where
            ``n_genes = full_exp.shape[1]``;
          * central_point_indices: 1-element LongTensors holding the row index
            of the representative spot in ``full_xys`` / ``full_exp``.
        Voxels are visited in ``set`` iteration order of their ids. Coarsening
        the same inputs twice therefore yields the same ordering.
    """
    voxels = voxel_grid(torch.from_numpy(full_xys), size=(size, size))
    # Taken from the data rather than hard-coded, so any gene-panel size works.
    n_genes = full_exp.shape[1]
    coarse_xys = []
    coarse_exp = []
    central_point_indices = []
    for i in set(voxels.tolist()):
        # (k, 1) LongTensor with the row indices of the spots in voxel i.
        indices = torch.nonzero(voxels == i, as_tuple=False)
        rows = indices.view(-1).numpy()
        xys = torch.from_numpy(full_xys[rows]).reshape((-1, 2))
        centroid = xys.mean(dim=0)
        distances = torch.norm(xys - centroid, dim=1)
        most_central_point_index = torch.argmin(distances)
        most_central_point = xys[most_central_point_index]
        exp_rows = full_exp[rows]
        if hasattr(exp_rows, "toarray"):
            exp_rows = exp_rows.toarray()
        exps = torch.from_numpy(np.asarray(exp_rows)).reshape((-1, n_genes))
        mean_exp = exps.mean(dim=0)
        coarse_exp.append(mean_exp.numpy().reshape((1, n_genes)))
        coarse_xys.append(most_central_point.numpy().reshape((1, 2)))
        # indices[j] is a (1,)-shaped tensor; callers convert it with .item().
        central_point_indices.append(indices[most_central_point_index])
    return coarse_xys, coarse_exp, central_point_indices
def load(patient_id, size):
    """Build one voxel-coarsened AnnData per tissue section of a patient.

    The patient's sections are looked up in the section->patient mapping CSV
    and ordered by their ``layer`` value (ties broken by image name). For each
    section, the warped spot coordinates are coarsened with ``coarsen``. A new
    AnnData is then built from the representative spot of every voxel.

    Args:
        patient_id: value matched against the mapping's ``deident`` column.
        size: voxel edge length passed to ``coarsen``.

    Returns:
        list of AnnData, one per section in ascending ``layer`` order. Each
        holds the kept spots' expression rows, ``obs['array_col']`` /
        ``obs['array_row']``, ``obsm['spatial']`` and the section's
        ``uns['spatial']``.
    """
    # NOTE(review): pickle.load can execute arbitrary code from the file; only
    # use this with trusted, locally produced artifacts.
    with open(PATH_TO_WARPED_XYS, 'rb') as fh:
        warped = pickle.load(fh)
    mappings_path = PATH_TO_SECTION2PATIENT_MAPPING_TXT
    adata_path = PATH_TO_INFERRED_ST_ANNDATA_SAVED
    metadata_path = PATH_TO_METADATA
    mappings = pd.read_csv(mappings_path)
    patient_rows = mappings[mappings['deident'] == patient_id]
    slides = list(patient_rows['image_name'].values)
    layers = list(patient_rows['layer'].values)
    # Order sections by layer, then point names at the saved .h5ad files
    # (these names are also the keys of the warped-coordinates dict).
    slides = [s.replace('svs', 'h5ad') for _, s in sorted(zip(layers, slides))]
    adatas = []
    print(sorted(layers))
    for slide in slides:
        xys = warped[slide]
        adata = sc.read(adata_path + slide)
        _, _, keep_indices = coarsen(xys, adata.X, size)
        keep_indices = [t.item() for t in keep_indices]
        sample_id = adata.obs['sample'].iloc[0]
        metadata = pd.read_csv(metadata_path + sample_id + '_metadata.csv')
        # assumes metadata rows are in the same order as adata rows -- TODO confirm
        array_row = np.array(list(metadata['array_row']))
        array_col = np.array(list(metadata['array_column']))
        new_adata = ad.AnnData(
            X=adata.X[keep_indices],
            obs={'array_col': array_col[keep_indices], 'array_row': array_row[keep_indices]},
            obsm={'spatial': adata.obsm['spatial'][keep_indices]},
            uns={'spatial': adata.uns['spatial']},
        )
        new_adata.var.index = adata.var.index
        adatas.append(new_adata)
    return adatas

In [1]:
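# CH SCORING: Calinski-Harabasz index of the Leiden domains (cluster compactness vs. separation; no ground-truth labels needed)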

In [None]:
def domains_3d_stitch(patient, resolution, voxel_size):
    """Leiden domains over all sections, using the Stitch3D latent representation.

    Reads every saved ``*adata*`` file from the Stitch3D output directory and
    tags each with its position in sorted order as ``obs['section_id']``. The
    sections are concatenated, and only the spots listed in
    ``representation.csv`` are kept, in that file's order. Those spots are then
    clustered on the latent embedding.

    Args:
        patient: not used by the body; the data location comes from the
            ``OUTPUTS_OF_STITCH3D`` / ``OUTPUT_OF_STITCH_3D_FOR_THIS_PATIENT``
            settings. Kept for interface compatibility.
        resolution: Leiden resolution.
        voxel_size: not used by the body; kept for interface compatibility.

    Returns:
        AnnData with ``obsm['latent']``, a 30-nearest-neighbour graph, a UMAP
        and ``obs['leiden']`` labels.
    """
    saved_dir = OUTPUTS_OF_STITCH3D
    items = OUTPUT_OF_STITCH_3D_FOR_THIS_PATIENT
    adata_files = sorted(item for item in items if 'adata' in item)
    # os.path.join keeps both reads consistent whether or not saved_dir has a
    # trailing separator.
    sections = [sc.read_h5ad(os.path.join(saved_dir, name)) for name in adata_files]
    print(saved_dir)
    for section_id, section in enumerate(sections):
        section.obs['section_id'] = section_id
    combined = ad.concat(sections)
    latent = pd.read_csv(os.path.join(saved_dir, 'representation.csv'), index_col=0)
    # Subsetting returns a view; copy so the writes below target a real object
    # instead of implicitly materialising the view.
    adata_all = combined[latent.index].copy()
    adata_all.obsm['latent'] = np.array(latent.values)
    sc.pp.neighbors(adata_all, use_rep='latent', n_neighbors=30)
    sc.tl.umap(adata_all)
    sc.tl.leiden(adata_all, resolution=resolution)
    return adata_all

def domains_top_slice_norm(patient, resolution, voxel_size):
    """Leiden domains of the first (lowest-layer) section, on raw expression.

    This is the baseline for the Stitch3D comparison: the first section
    returned by ``load`` is clustered directly on its expression matrix, with
    no latent embedding.

    Args:
        patient: patient id forwarded to ``load``.
        resolution: Leiden resolution.
        voxel_size: voxel edge length forwarded to ``load``.

    Returns:
        The coarsened AnnData of the first section, with a 30-nearest-neighbour
        graph built on ``X``, a UMAP and ``obs['leiden']`` labels.
    """
    # NOTE(review): load() coarsens every section of the patient, although only
    # the first one is used here.
    top_section = load(patient, voxel_size)[0]
    sc.pp.neighbors(top_section, n_neighbors=30, use_rep='X')
    sc.tl.umap(top_section)
    sc.tl.leiden(top_section, resolution=resolution)
    return top_section

In [None]:
#Calculate CH scores
def run_analysis_ch(resolutions=None, voxel_size=1500):
    """Compare Calinski-Harabasz scores of Stitch3D vs raw-expression domains.

    For every Leiden resolution and every patient in ``PATIENT_IDS``, two
    clusterings are scored:
      * "full": Leiden clusters of all stitched sections, scored on the
        Stitch3D latent space (``domains_3d_stitch``);
      * "raw": Leiden clusters of the first section, scored on its expression
        matrix (``domains_top_slice_norm``).

    Args:
        resolutions: Leiden resolutions to sweep. Defaults to 0.05..0.95 in
            steps of 0.1 (the previously hard-coded grid).
        voxel_size: voxel edge length used when coarsening sections.

    Returns:
        ``(full_avgs, raw_avgs)``: the per-resolution mean CH score across
        patients.

    Raises:
        ValueError: if ``PATIENT_IDS`` is empty, since no average can be formed.
    """
    if resolutions is None:
        resolutions = [0.05, 0.15, 0.25, 0.35, 0.45, 0.55, 0.65, 0.75, 0.85, 0.95]
    patients = list(PATIENT_IDS)
    if not patients:
        raise ValueError('PATIENT_IDS is empty; nothing to score')
    full_avgs = []
    raw_avgs = []
    for resolution in resolutions:
        full_scores = []
        raw_scores = []
        for patient in patients:
            full = domains_3d_stitch(patient, resolution, voxel_size)
            full_clusters = [int(v) for v in full.obs['leiden']]
            raw = domains_top_slice_norm(patient, resolution, voxel_size)
            raw_clusters = [int(v) for v in raw.obs['leiden']]
            full_scores.append(calinski_harabasz_score(full.obsm['latent'], full_clusters))
            raw_scores.append(calinski_harabasz_score(raw.X, raw_clusters))
        full_avg = sum(full_scores) / len(full_scores)
        raw_avg = sum(raw_scores) / len(raw_scores)
        print(full_avg)
        print(raw_avg)
        print('')
        full_avgs.append(full_avg)
        raw_avgs.append(raw_avg)
    return full_avgs, raw_avgs

In [2]:
# FM SCORING: Fowlkes-Mallows agreement between Leiden domains and pathology annotations

In [3]:
def domains_two5_stitch(patient, resolution, voxel_size, select_slice=0):
    """Leiden domains of a single section, using the Stitch3D latent space.

    The latent embedding is learned jointly over all sections (the "2.5D"
    setting), but only the spots of one section are clustered.

    Args:
        patient: not used by the body; kept for interface compatibility.
        resolution: Leiden resolution.
        voxel_size: not used by the body; kept for interface compatibility.
        select_slice: section number to keep. A spot is selected when
            ``'slice' + str(select_slice)`` occurs in its index label in
            ``representation.csv``. Defaults to 0, the original behaviour.

    Returns:
        AnnData of the selected section's spots, with ``obsm['latent']``, a
        30-nearest-neighbour graph, a UMAP and ``obs['leiden']`` labels.
    """
    saved_dir = OUTPUTS_OF_STITCH3D
    items = OUTPUT_OF_STITCH_3D_FOR_THIS_PATIENT
    adata_files = sorted(item for item in items if 'adata' in item)
    sections = [sc.read_h5ad(os.path.join(saved_dir, name)) for name in adata_files]
    print(saved_dir)
    for section_id, section in enumerate(sections):
        section.obs['section_id'] = section_id
    # The sections must be concatenated into one AnnData before subsetting by
    # observation name; indexing the plain Python list raised TypeError.
    combined = ad.concat(sections)
    latent = pd.read_csv(os.path.join(saved_dir, 'representation.csv'), index_col=0)
    tag = 'slice' + str(select_slice)
    print(tag)
    # NOTE(review): this is a substring match, so 'slice1' would also match
    # 'slice10...'; confirm against the index naming written by Stitch3D.
    index = [item for item in latent.index if tag in item]
    latent_indexed = latent[latent.index.isin(index)]
    adata_top = combined[latent_indexed.index].copy()
    adata_top.obsm['latent'] = np.array(latent_indexed.values)
    sc.pp.neighbors(adata_top, use_rep='latent', n_neighbors=30)
    sc.tl.umap(adata_top)
    sc.tl.leiden(adata_top, resolution=resolution)
    return adata_top
def gen_annot_labels(patient_id, size, annots_base_path='../outs/warped_annots/'):
    """Collect the annotation label of every coarsened spot of a patient.

    Sections are selected and ordered as in ``load`` (ascending ``layer``).
    Each section is coarsened with ``coarsen``, and the representative spot
    coordinates are passed, together with that section's warped GeoJSON file,
    to ``get_annots_xys``.

    Args:
        patient_id: value matched against the mapping's ``deident`` column.
        size: voxel edge length passed to ``coarsen``.
        annots_base_path: directory prefix holding one ``<section>.geojson``
            per section. Defaults to the previously hard-coded location.

    Returns:
        list: the concatenation, over sections in layer order, of what
        ``get_annots_xys`` returns for each section (presumably one label per
        coarsened spot, in ``coarsen`` voxel order -- confirm in
        get_annots_xys).
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(PATH_TO_WARPED_XYS, 'rb') as fh:
        warped = pickle.load(fh)
    adata_path = PATH_TO_INFERRED_ST_ANNDATA_SAVED
    mappings = pd.read_csv(PATH_TO_SECTION2PATIENT_MAPPING_TXT)
    patient_rows = mappings[mappings['deident'] == patient_id]
    slides = list(patient_rows['image_name'].values)
    layers = list(patient_rows['layer'].values)
    slides = [s.replace('svs', 'h5ad') for _, s in sorted(zip(layers, slides))]
    print(sorted(layers))
    all_colors = []
    for slide in slides:
        xys = warped[slide]
        adata = sc.read(adata_path + slide)
        coarse_xys, _, _ = coarsen(xys, adata.X, size)
        annots = get_annots_xys(coarse_xys, annots_base_path + slide.replace('h5ad', 'geojson'))
        all_colors += annots
    return all_colors

In [None]:
#Run for FM scoring
def run_analysis_fm(resolutions=None, voxel_size=1500):
    """Score how well Leiden domains agree with pathology annotations (Fowlkes-Mallows).

    For every Leiden resolution and every patient in ``PATIENT_IDS[1:]``, the
    first section is clustered in two ways, and each clustering is scored
    against the labels from ``gen_annot_labels``:
      * "single": Stitch3D latent space restricted to slice 0
        (``domains_two5_stitch``);
      * "raw": raw expression of the first section (``domains_top_slice_norm``).

    Args:
        resolutions: Leiden resolutions to sweep. Defaults to ``RESOLUTIONS_TO_TEST``.
        voxel_size: voxel edge length used when coarsening sections.

    Returns:
        ``(single_avgs, raw_avgs)``: the per-resolution mean FM score across
        patients.

    Raises:
        ValueError: if no patients remain after skipping the first one.
    """
    if resolutions is None:
        resolutions = RESOLUTIONS_TO_TEST
    # NOTE(review): the first patient is skipped, as in the original code --
    # confirm this is intentional (e.g. no annotations for that patient).
    patients = list(PATIENT_IDS)[1:]
    if not patients:
        raise ValueError('no patients to score after skipping PATIENT_IDS[0]')
    # Annotation labels do not depend on the resolution, so compute them once per patient.
    annots_cache = {}
    single_avgs = []
    raw_avgs = []
    for resolution in resolutions:
        single_scores = []
        raw_scores = []
        for patient in patients:
            single_3d = domains_two5_stitch(patient, resolution, voxel_size)
            single_clusters = [int(v) for v in single_3d.obs['leiden']]
            raw = domains_top_slice_norm(patient, resolution, voxel_size)
            raw_clusters = [int(v) for v in raw.obs['leiden']]
            if patient not in annots_cache:
                annots_cache[patient] = gen_annot_labels(patient, voxel_size)
            # gen_annot_labels concatenates the labels of ALL sections in layer
            # order, but both clusterings cover only the first section. load()
            # and gen_annot_labels() coarsen that first section identically, so
            # its labels are the first len(raw_clusters) entries.
            # assumes Stitch3D's slice0 spots are those same coarsened spots in
            # the same order -- TODO confirm
            annots_true = annots_cache[patient][:len(raw_clusters)]
            single_score = fowlkes_mallows_score(annots_true, single_clusters)
            raw_score = fowlkes_mallows_score(annots_true, raw_clusters)
            single_scores.append(single_score)
            raw_scores.append(raw_score)
            print(single_score, raw_score)
        single_avg = sum(single_scores) / len(single_scores)
        raw_avg = sum(raw_scores) / len(raw_scores)
        print(single_avg)
        print(raw_avg)
        print('')
        single_avgs.append(single_avg)
        raw_avgs.append(raw_avg)
    return single_avgs, raw_avgs