In [None]:
from shapely.geometry import shape, GeometryCollection, Point, Polygon
import os
import json
import pickle
import pandas as pd
import openslide
import cv2
from pathpretrain import load_image
import geojson
import matplotlib.pyplot as plt
import numpy as np
import geopandas as gpd
import scanpy as sc
import torch
from torch_geometric.nn.pool import voxel_grid
import plotly.graph_objects as go
import anndata as ad
import plotly.io as pio
import plotly.express as px
from scipy import stats

In [None]:
def coarsen(full_xys, full_exp, size):
    """Downsample points onto a square voxel grid.

    Every point is assigned to a ``size`` x ``size`` voxel with
    ``torch_geometric``'s ``voxel_grid``. For each occupied voxel, the point
    closest (Euclidean) to the voxel's coordinate centroid is kept as its
    representative, and the expression rows of all points in the voxel are
    averaged.

    Args:
        full_xys: numpy array of point coordinates; each voxel's points are
            reshaped to ``(-1, 2)``. Presumably float, since a mean is taken.
        full_exp: expression matrix with one row per point (dense numpy array
            or a scipy sparse matrix). The number of genes is taken from
            ``full_exp.shape[1]``.
        size: voxel edge length, in the units of ``full_xys``.

    Returns:
        ``(coarse_xys, coarse_exp, central_point_indices)``: per-voxel lists,
        ordered by ascending voxel id. ``coarse_xys`` holds ``(1, 2)`` arrays
        with the representative point's coordinates. ``coarse_exp`` holds
        ``(1, n_genes)`` arrays with the voxel's mean expression.
        ``central_point_indices`` holds 1-element int64 tensors with the
        representative point's row index into ``full_xys`` / ``full_exp``.
    """
    if len(full_xys) == 0:
        return [], [], []

    voxels = voxel_grid(torch.from_numpy(full_xys), size=(size, size))
    voxel_ids = voxels.cpu().numpy()
    n_genes = full_exp.shape[1]

    # Group point indices by voxel in a single pass instead of rescanning every
    # point once per voxel. A stable argsort over each point's rank among the
    # sorted unique voxel ids yields voxel 0's points, then voxel 1's, ...,
    # each group in ascending row order.
    _, group_of_point = np.unique(voxel_ids, return_inverse=True)
    group_of_point = group_of_point.ravel()
    order = np.argsort(group_of_point, kind="stable")
    group_sizes = np.bincount(group_of_point)
    groups = np.split(order, np.cumsum(group_sizes)[:-1])

    coarse_xys = []
    coarse_exp = []
    central_point_indices = []
    for idx in groups:
        xys = torch.from_numpy(full_xys[idx]).reshape((-1, 2))
        centroid = xys.mean(dim=0)
        distances = torch.norm(xys - centroid, dim=1)
        # torch.argmin returns the first minimum on ties (lowest row index).
        local = int(torch.argmin(distances))

        rows = full_exp[idx]
        if hasattr(rows, "toarray"):  # scipy sparse rows -> dense
            rows = rows.toarray()
        mean_exp = np.asarray(rows).reshape((-1, n_genes)).mean(axis=0)

        coarse_exp.append(mean_exp.reshape((1, n_genes)))
        coarse_xys.append(xys[local].numpy().reshape((1, 2)))
        central_point_indices.append(torch.tensor([idx[local]]))
    return coarse_xys, coarse_exp, central_point_indices
def load(patient_id, size, warped_coords=None):
    """Load every section of one patient and keep one representative spot per voxel.

    Sections are looked up in the section-to-patient mapping CSV, ordered by
    their ``layer`` value (ties broken by image name), and read from
    ``PATH_TO_INFERRED_ST_ANNDATA_SAVED``. Each section's coordinates are
    voxel-downsampled with ``coarsen``. Only the representative spots' rows are
    kept, together with ``array_row`` / ``array_column`` from the sample's
    metadata CSV.

    Args:
        patient_id: value matched against the mapping's ``deident`` column.
        size: voxel edge length passed to ``coarsen``.
        warped_coords: mapping from AnnData file name (the slide name with
            ``'svs'`` replaced by ``'h5ad'``) to that section's coordinates.
            Defaults to the notebook-level ``warped`` dict, so existing calls
            keep working.

    Returns:
        List of downsampled ``AnnData`` objects, one per section, in layer order.
    """
    mappings_path = PATH_TO_SECTION2PATIENT_MAPPING_TXT
    adata_path = PATH_TO_INFERRED_ST_ANNDATA_SAVED
    metadata_path = PATH_TO_METADATA

    mappings = pd.read_csv(mappings_path)
    patient_rows = mappings[mappings['deident'] == patient_id]
    slides = list(patient_rows['image_name'].values)
    layers = list(patient_rows['layer'].values)
    # Order sections by layer, then map slide image names to AnnData file names.
    slides = [s.replace('svs', 'h5ad') for _, s in sorted(zip(layers, slides))]
    print(sorted(layers))

    adatas = []
    for slide in slides:
        # Resolve the global lazily: it is only needed when there is a section to load.
        coords_by_slide = warped if warped_coords is None else warped_coords
        xys = coords_by_slide[slide]
        adata = sc.read(adata_path + slide)
        # Only the representative row indices are needed here.
        _, _, keep_indices = coarsen(xys, adata.X, size)
        keep = [t.item() for t in keep_indices]

        sample_id = adata.obs['sample'].iloc[0]
        metadata = pd.read_csv(metadata_path + sample_id + '_metadata.csv')
        # NOTE(review): assumes metadata rows are in the same order as adata's
        # rows (indexed positionally by `keep`); confirm.
        array_row = np.array(list(metadata['array_row']))
        array_col = np.array(list(metadata['array_column']))
        new_adata = ad.AnnData(
            X=adata.X[keep],
            obs={'array_col': array_col[keep], 'array_row': array_row[keep]},
            obsm={'spatial': adata.obsm['spatial'][keep]},
            uns={'spatial': adata.uns['spatial']},
        )
        new_adata.var.index = adata.var.index
        adatas.append(new_adata)
    return adatas

In [None]:
def _leiden_domain_set(adata, latent, resolution):
    """Cluster the rows of ``adata`` named in ``latent.index`` on ``latent``; return the set of Leiden labels."""
    subset = adata[latent.index].copy()
    subset.obsm['latent'] = np.array(latent.values)
    sc.pp.neighbors(subset, use_rep='latent', n_neighbors=30)
    sc.tl.leiden(subset, resolution=resolution)
    return set(subset.obs['leiden'])


def domains_3d_stitch(patient, resolution, voxel_size):
    """Find Leiden domains on a patient's Stitch3D latent representation.

    Every file in the patient's Stitch3D output directory whose name contains
    ``'adata'`` is read and concatenated, and the latent embedding is read
    from ``representation.csv`` in that same directory. Leiden clustering is
    run twice: on all embedded spots, and on only the spots whose name
    contains ``'slice0'``.

    Args:
        patient: patient id, appended to ``PATH_TO_STITCH3D_OUTPUTS_PATIET``.
        resolution: Leiden resolution.
        voxel_size: unused; kept so the signature matches ``domains_top_slice_norm``.

    Returns:
        ``(all_domains, top_slice_domains)``: sets of Leiden labels.
    """
    # List and read from the same per-patient directory. The files listed
    # here must also be the ones loaded.
    patient_dir = PATH_TO_STITCH3D_OUTPUTS_PATIET + patient
    # NOTE(review): sorting is lexicographic ('adata_10' < 'adata_2'); the
    # section_id written below is not used further in this function.
    adata_files = sorted(name for name in os.listdir(patient_dir) if 'adata' in name)
    sections = [sc.read_h5ad(os.path.join(patient_dir, name)) for name in adata_files]
    for section_id, section in enumerate(sections):
        section.obs['section_id'] = section_id
    combined = ad.concat(sections)
    latent = pd.read_csv(os.path.join(patient_dir, 'representation.csv'), index_col=0)

    all_domains = _leiden_domain_set(combined, latent, resolution)

    top_names = [name for name in latent.index if 'slice0' in name]
    top_latent = latent[latent.index.isin(top_names)]
    top_domains = _leiden_domain_set(combined, top_latent, resolution)
    return all_domains, top_domains

def domains_top_slice_norm(patient, resolution, voxel_size):
    """Find Leiden domains on the coarsened expression of a patient's first section.

    Clusters the first section returned by ``load`` (lowest ``layer``) on a
    30-nearest-neighbour graph built directly from ``X``, without any
    Stitch3D embedding.

    Args:
        patient: patient id passed to ``load``.
        resolution: Leiden resolution.
        voxel_size: voxel edge length passed to ``load``.

    Returns:
        Set of Leiden labels.
    """
    # NOTE(review): load() reads and coarsens every section of the patient,
    # but only the first one is used here.
    sections = load(patient, voxel_size)
    top_section = sections[0]
    sc.pp.neighbors(top_section, n_neighbors=30, use_rep='X')
    sc.tl.leiden(top_section, resolution=resolution)
    return set(top_section.obs['leiden'])

In [None]:
def mean_confidence_interval(data, confidence=0.95):
    """Return the sample mean and the half-width of its Student-t confidence interval.

    The interval is ``mean ± half_width``, where ``half_width`` is the
    standard error of the mean times the two-sided t critical value with
    ``len(data) - 1`` degrees of freedom.
    """
    sample_size = len(data)
    upper_tail_quantile = (1 + confidence) / 2
    t_critical = stats.t.ppf(upper_tail_quantile, sample_size - 1)
    half_width = stats.sem(data) * t_critical
    return np.mean(data), half_width

In [None]:
def run_analysis():
    """Print the mean number of spatial domains, with a 95% CI, for each setting.

    For each resolution in ``RESOLUTIONS_TO_TEST``, every patient in
    ``PATIENT_IDS`` is clustered three ways:

    - Stitch3D, all sections ("Full Domains");
    - Stitch3D, top section only ("Single 3D Domains");
    - coarsened top section without Stitch3D ("Raw Domains").

    The per-patient domain counts are then summarised across patients.
    """
    patients = PATIENT_IDS
    for resolution in RESOLUTIONS_TO_TEST:
        counts = {"Full Domains": [], "Single 3D Domains": [], "Raw Domains": []}
        for patient in patients:
            # NOTE(review): patient '0' uses a smaller voxel size (1000 vs 1500);
            # the reason is not visible here — confirm.
            voxel_size = 1000 if patient == '0' else 1500
            full, single_3d = domains_3d_stitch(patient, resolution, voxel_size)
            raw = domains_top_slice_norm(patient, resolution, voxel_size)
            counts["Full Domains"].append(len(full))
            counts["Single 3D Domains"].append(len(single_3d))
            counts["Raw Domains"].append(len(raw))

        print(f"Resolution = {resolution}")
        for name, values in counts.items():
            mean, half_width = mean_confidence_interval(values)
            # Print the interval bounds, not just the half-width.
            print(f"{name}:\n  Mean = {mean:.2f}\n"
                  f"  95% CI = [{mean - half_width:.2f}, {mean + half_width:.2f}]\n")
        print('')