In [None]:
import pandas as pd
import numpy as np
from pathlib import Path
import scanpy as sc
import squidpy as sq
import sys
import json
import time
import os
import shutil
from tqdm import tqdm
import anndata as ad
sys.path.append('../')
from src.preprocess_utils.preprocess_image import get_low_res_image
from src.utils import create_cross_validation_folds
from src.utils import preprocess_adata
pd.set_option('display.max_columns', None)

In [None]:
# OncoTree codes of the cohorts to process; the metadata table is later
# restricted to rows whose `oncotree_code` is in this list.
# Wider selection kept for reference (currently disabled):
#dataset_of_interest = ['IDC', 'PRAD', 'PAAD', 'SKCM', 'COAD', 'READ', 'CCRCC', 'HCC', 'LUNG', 'LYMPH_IDC']
dataset_of_interest = ['SCCRCC', 'IDC', 'PAAD', 'COAD']

In [None]:
# Load the HEST sample metadata and narrow it down one criterion at a time.
df = pd.read_csv('data/HEST/metadata/HEST_v1_0_0.csv')

# Human samples that carry a patient identifier.
df = df.loc[df['species'] == 'Homo sapiens']
df = df.loc[df['patient'].notna()]

# Technologies whose name contains "Visium".
is_visium = df['st_technology'].apply(lambda tech: 'Visium' in tech)
df = df.loc[is_visium]

# Keep non-IDC rows, plus IDC rows measured with Xenium.
# NOTE(review): after the Visium filter above no row can have
# st_technology == "Xenium", so this step removes every IDC row even though
# 'IDC' is listed in `dataset_of_interest` -- confirm this is intended.
is_idc = df['oncotree_code'] == "IDC"
is_xenium = df['st_technology'] == "Xenium"
df = df.loc[~is_idc | (is_idc & is_xenium)]

# Restrict to the selected cohorts and drop two explicitly excluded samples.
df = df.loc[df['oncotree_code'].isin(dataset_of_interest)]
excluded_ids = ["TENX98", "TENX97"]
df = df.loc[~df['id'].isin(excluded_ids)]
df.head()

In [None]:
# Number of retained samples per OncoTree code.
df['oncotree_code'].value_counts()

In [None]:
# Number of retained samples per (OncoTree code, patient) pair.
df.loc[:, ['oncotree_code', 'patient']].value_counts()

In [None]:
def create_yaml_file(path, YAML_TEXT):
    """Write ``YAML_TEXT`` to ``path``, replacing any existing content.

    Parameters
    ----------
    path : str or path-like
        Destination file; its parent directory must already exist.
    YAML_TEXT : str
        Text written to the file as-is (it is not parsed or validated).
    """
    Path(path).write_text(YAML_TEXT)


def create_folder(newpath):
    """Create the directory ``newpath`` (and missing parents) if it is absent.

    Calling this on a directory that already exists is a no-op. The creation
    is done in a single ``os.makedirs`` call instead of an exists-check
    followed by a create, so a directory appearing in between cannot make the
    call fail.

    Parameters
    ----------
    newpath : str or path-like
        Directory to create.

    Raises
    ------
    FileExistsError
        If ``newpath`` exists but is not a directory.
    """
    os.makedirs(newpath, exist_ok=True)

In [None]:
# Root directory under which one output folder per OncoTree code is created.
base_path = "../"
# Factor passed to get_low_res_image for the image embedded in each exported
# AnnData; also divides the `spatial` coordinates and is recorded in the
# per-sample JSON metadata.
downsample_factor = 10
# Recorded unchanged in the per-sample JSON metadata.
dot_size = 10

In [None]:
# For every OncoTree code: write the cross-validation folds, load and
# preprocess all of its samples, integrate them across patients with Harmony,
# and run Leiden clustering. The per-spot `obs` tables of all cohorts are
# collected in `all_oncotree_code_cluster` (indexed by
# "<barcode>_<sample_id>_<oncotree_code>").
import yaml


class _NoAliasDumper(yaml.Dumper):
    """YAML dumper that writes repeated objects in full instead of anchors/aliases."""

    def ignore_aliases(self, data):
        return True


cluster_frames = []
for oncotree_code in tqdm(df.oncotree_code.unique()):
    # assign() returns a new frame, so the extra column is never written into
    # a slice of `df` (which would raise SettingWithCopyWarning).
    subset_hest = df[df.oncotree_code == oncotree_code].assign(
        groups=lambda frame: frame['patient'].values
    )
    print(f"{oncotree_code}: {len(subset_hest)}")

    # patient -> list of sample ids, in order of first appearance. A boolean
    # mask is used instead of an f-string query so patient ids containing
    # quotes (or non-string ids) are matched correctly.
    patient_replicate_pairs = {
        patient: subset_hest.loc[subset_hest['groups'] == patient, 'id'].tolist()
        for patient in subset_hest['groups'].unique().tolist()
    }
    folds = create_cross_validation_folds(patient_replicate_pairs)

    # Output layout: <base_path>/<code>/data/{h5ad,image,meta}
    cohort_dir = Path(f"{base_path}/{oncotree_code}")
    for subfolder in ("h5ad", "image", "meta"):
        (cohort_dir / "data" / subfolder).mkdir(parents=True, exist_ok=True)

    with open(cohort_dir / 'cross_validation_config.yaml', 'w') as ff:
        yaml.dump(folds, ff, Dumper=_NoAliasDumper, default_flow_style=False)

    sample_adatas = []
    for _, row in subset_hest.iterrows():
        sample_id = row['id']
        adata = sc.read_h5ad(f"data/HEST/data/st/{sample_id}.h5ad")

        # Work on a dense matrix; anything that is not already an ndarray is
        # converted through its toarray() method, as before.
        if not isinstance(adata.X, np.ndarray):
            adata.X = adata.X.toarray()

        # Drop spots without any counts; copy so later steps operate on a
        # real AnnData rather than a view.
        empty_spot = adata.X.sum(axis=1) == 0
        adata = adata[~empty_spot, :].copy()

        adata = preprocess_adata(adata, run_dim_red=False)
        # Make spot names unique across samples and cohorts.
        adata.obs_names = [f"{i}_{sample_id}_{oncotree_code}" for i in adata.obs_names]
        # The patient id is the batch key used for Harmony integration below.
        adata.obs['batch'] = row['groups']
        sample_adatas.append(adata)

    all_adatas = ad.concat(sample_adatas)
    sc.pp.pca(all_adatas)
    sc.external.pp.harmony_integrate(all_adatas, key="batch")
    sc.pp.neighbors(all_adatas, use_rep="X_pca_harmony")
    sc.tl.leiden(all_adatas, resolution=0.2)
    cluster_frames.append(all_adatas.obs)

all_oncotree_code_cluster = pd.concat(cluster_frames)
all_oncotree_code_cluster.head()

In [None]:
# Concatenated AnnData left over from the last iteration of the clustering
# loop (the variable is reassigned for every OncoTree code).
all_adatas

In [None]:
def load_TEMPLATE():
    """Return the text template for a per-cohort ``config_dataset.yaml``.

    The template contains a single ``{}`` placeholder under ``SAMPLE:``;
    callers fill it with ``str.format`` using a pre-formatted block of sample
    entries. Everything else (models, image features, gene sets, numeric
    settings) is fixed text.

    Returns
    -------
    str
        The unformatted template.
    """
    return """
SAMPLE:
{}

MODEL:
    - LinearRegression
    - Ridge
#    - XGB
    - DeepSpot
    - MLP
    - HisToGene
    - Hist2ST
    - THItoGene
    - BLEEP

IMAGE_FEATURES:
    - inception
    - phikon
    - uni

GENE_SET:
    - GO_Biological_Process_2023 
    - GO_Cellular_Component_2023
    - KEGG_2021_Human
    - MSigDB_Hallmark_2020
    - GO_Molecular_Function_2023
    - Reactome_2022

top_n_genes_to_predict: 5000
genes_to_evaluate_decentile: 20
n_mini_tiles: 4

DATASET: "HEST"
OUT_FOLDER: "out_benchmark"

DOWNSAMPLE_FACTOR: 10
IMAGE_FORMAT: "tif"

known_genes:

"""

In [None]:
# For every OncoTree code: write the dataset config, then export each sample
# as an annotated .h5ad, a downsampled .tif image and a JSON metadata file.
# Requires `all_oncotree_code_cluster` from the clustering cell.
from PIL import Image

# Pixel coordinates, spot diameter and the exported image are halved relative
# to the full-resolution slide ("20x magnification" in the original notes;
# presumably the source slides are 40x -- TODO confirm).
scale_to_20x = 0.5
downsample_to_20x = 2

for oncotree_code in tqdm(df.oncotree_code.unique()):
    subset_hest = df[df.oncotree_code == oncotree_code]

    # One "   - <sample id>" entry per sample fills the SAMPLE: section.
    sample_string = "".join(f"   - {sample}\n" for sample in subset_hest.id.values)
    text = load_TEMPLATE().format(sample_string)
    create_yaml_file(f'{base_path}/{oncotree_code}/config_dataset.yaml', text)

    for sample_id in subset_hest.id.values:
        adata_path = f"data/HEST/data/st/{sample_id}.h5ad"
        adata_out_path = f"{base_path}/{oncotree_code}/data/h5ad/{sample_id}.h5ad"
        image_path = f"data/HEST/data/wsis/{sample_id}.tif"
        image_out_path = f"{base_path}/{oncotree_code}/data/image/{sample_id}.tif"
        json_path = f"data/HEST/data/metadata/{sample_id}.json"
        json_out_path = f"{base_path}/{oncotree_code}/data/meta/{sample_id}.json"

        # Context manager so the metadata file handle is closed promptly.
        with open(json_path) as f:
            json_info = json.load(f)

        adata = sc.read_h5ad(adata_path)
        if not isinstance(adata.X, np.ndarray):
            adata.X = adata.X.toarray()

        # Drop spots without any counts; copy so the assignments below write
        # into a real AnnData instead of a view.
        empty_spot = adata.X.sum(axis=1) == 0
        if empty_spot.sum() > 0:
            print(f"Found empty spot in {oncotree_code} - {sample_id} - {empty_spot.sum()}")
        adata = adata[~empty_spot, :].copy()

        # Look up the cluster of every spot under the same
        # "<barcode>_<sample_id>_<oncotree_code>" key used during clustering.
        cluster_idx = [f"{i}_{sample_id}_{oncotree_code}" for i in adata.obs.index]
        cluster_labels = all_oncotree_code_cluster.loc[cluster_idx, 'leiden'].astype(str)
        # NOTE(review): the label prefix says 0.3 while the clustering cell
        # calls sc.tl.leiden with resolution=0.2 -- confirm which is intended
        # before renaming, since downstream code may rely on this prefix.
        adata.obs['leiden'] = [f"leiden_0.3_{label}" for label in cluster_labels.values]

        sc.pp.filter_genes(adata, min_counts=1)
        # Read before adata.uns['spatial'] is replaced further down.
        spot_diameter_fullres = adata.uns['spatial']['ST']['scalefactors']['spot_diameter_fullres']

        # float() keeps the value JSON-serialisable whatever numeric type was stored.
        json_info['spot_diameter_fullres'] = float(spot_diameter_fullres) * scale_to_20x
        json_info['downsample_factor'] = downsample_factor
        json_info['dot_size'] = dot_size

        image = get_low_res_image(image_path, downsample_factor=downsample_factor)
        adata.obs['x_pixel'] = adata.obs['pxl_row_in_fullres'] * scale_to_20x
        adata.obs['y_pixel'] = adata.obs['pxl_col_in_fullres'] * scale_to_20x
        adata.obs['x_array'] = adata.obs['array_row']
        adata.obs['y_array'] = adata.obs['array_col']

        # (col, row) full-resolution coordinates rescaled to the low-res image.
        fullres_xy = adata.obs[["pxl_col_in_fullres", "pxl_row_in_fullres"]].values
        adata.obsm['spatial'] = fullres_xy / downsample_factor
        # Replace the 'spatial' entry with one holding only the low-res image.
        adata.uns['spatial'] = {'library_id': {'images': {'hires': image}}}

        adata.obs['barcode'] = adata.obs.index
        adata.var["gene_symbol"] = adata.var.index

        adata.write_h5ad(adata_out_path)

        image_20x = get_low_res_image(image_path, downsample_factor=downsample_to_20x)
        Image.fromarray(image_20x).save(image_out_path)

        with open(json_out_path, 'w') as f:
            json.dump(json_info, f)
    
    