In [None]:
from aestetik.utils.utils_morphology import extract_morphology_embeddings
from aestetik.utils.utils_transcriptomics import preprocess_adata
from aestetik import AESTETIK
# Show the AESTETIK version in the cell output (presumably the installed package version —
# useful context when comparing these run-time measurements across releases).
AESTETIK.version()

In [None]:
import plotnine as p9
import squidpy as sq
import time
import scanpy as sc
import pandas as pd
import numpy as np
import torch
import json

In [None]:
import logging
# pyvips: only CRITICAL records get through (its lower-level output is presumably noise here).
logging.getLogger("pyvips").setLevel(level=logging.CRITICAL)
# Root logger: attach a default handler and emit INFO and above for all other loggers.
logging.basicConfig(level=logging.INFO)

In [None]:
# Input files for one sample: its JSON metadata and its AnnData (.h5ad) object.
# Change SAMPLE_ID / DATA_DIR to benchmark a different sample or output directory.
DATA_DIR = "out_ablation/data"
SAMPLE_ID = "B"
json_path = f"{DATA_DIR}/meta/{SAMPLE_ID}.json"
adata_in = f"{DATA_DIR}/h5ad/{SAMPLE_ID}.h5ad"

In [None]:
# Number of leading embedding columns kept per modality (used when slicing obsm below).
n_components = 15

# Read the sample metadata once; the context manager closes the file handle.
with open(json_path) as meta_file:
    sample_meta = json.load(meta_file)
spot_diameter_fullres = sample_meta["spot_diameter_fullres"]
dot_size = sample_meta["dot_size"]
spot_diameter_fullres, dot_size

In [None]:
adata = sc.read_h5ad(adata_in)
# For a quick test run, uncomment to subsample 50,000 spots and/or preprocess here.
#adata = adata[adata.obs.sample(50000).index,:]
#adata = preprocess_adata(adata)

# Fail fast, before the long benchmark, if inputs read by the later cells are missing.
_missing_obsm = {"X_pca", "image"} - set(adata.obsm.keys())
assert not _missing_obsm, f"adata.obsm is missing {sorted(_missing_obsm)}"
assert "ground_truth" in adata.obs.columns, "adata.obs has no 'ground_truth' column"
adata

In [None]:
# we set the transcriptomics modality
# Transcriptomics modality: the first `n_components` columns of the precomputed X_pca embedding.
# Guard against silently getting fewer columns than requested.
assert adata.obsm["X_pca"].shape[1] >= n_components, (
    f"obsm['X_pca'] has only {adata.obsm['X_pca'].shape[1]} columns, need {n_components}"
)
adata.obsm["X_pca_transcriptomics"] = adata.obsm["X_pca"][:, :n_components]
adata.obsm["X_pca_transcriptomics"].shape

In [None]:
# we set the morphology modality
# Morphology modality: the first `n_components` columns of obsm["image"].
# NOTE(review): assumes the image embedding's leading columns are the most informative
# (e.g. already PCA-reduced) — TODO confirm how obsm["image"] was produced.
assert adata.obsm["image"].shape[1] >= n_components, (
    f"obsm['image'] has only {adata.obsm['image'].shape[1]} columns, need {n_components}"
)
adata.obsm["X_pca_morphology"] = adata.obsm["image"][:, :n_components]
adata.obsm["X_pca_morphology"].shape

In [None]:
import torch

# Display which device torch can use here (the value is shown, not stored or passed on).
torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
import multiprocessing

def run_time(adata, parameters=None):
    """Fit AESTETIK on ``adata`` and measure the wall-clock time of each pipeline stage.

    Parameters
    ----------
    adata : AnnData
        Must have ``obs["ground_truth"]``; its number of distinct non-missing labels
        sets ``nCluster``. Any other fields AESTETIK needs must already be present
        (TODO confirm AESTETIK's exact input requirements).
    parameters : dict, optional
        Keyword arguments forwarded to ``AESTETIK``. ``None`` (default) uses the
        benchmark configuration defined below.

    Returns
    -------
    tuple of float
        Seconds spent in ``prepare_input_for_model``, ``train`` and
        ``compute_spot_representations(cluster=True)``, in that order.
    """
    if parameters is None:
        parameters = {
            'morphology_weight': 1.5,
            'refine_cluster': 0,
            'window_size': 3,
            'epochs': 10,
            'batch_size': 100_000,
            'clustering_method': "kmeans",
            'n_jobs': 1,
        }

    # nunique() skips NaN, so unannotated spots are not counted as an extra cluster
    # (unique().size would include NaN as a distinct value).
    model = AESTETIK(adata,
                     nCluster=adata.obs.ground_truth.nunique(),
                     **parameters)

    def _timed(step):
        # Return seconds spent in step(). CUDA kernels run asynchronously, so synchronize
        # before and after reading the clock to attribute GPU work to the right stage.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.perf_counter()
        step()
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return time.perf_counter() - start

    time_prepare_input_for_model = _timed(model.prepare_input_for_model)
    time_train = _timed(model.train)
    time_compute_spot_representations = _timed(
        lambda: model.compute_spot_representations(cluster=True))
    return time_prepare_input_for_model, time_train, time_compute_spot_representations

In [None]:
def measure_time(adata, n_points, random_state=None):
    """Time AESTETIK on a random subsample of ``adata`` with at most ``n_points`` spots.

    Parameters
    ----------
    adata : AnnData
        Full dataset. It is not modified; the timed run works on an independent copy.
    n_points : int
        Target number of spots. If ``adata`` has ``n_points`` spots or fewer, the
        whole dataset is used.
    random_state : int or None, optional
        Seed for the subsampling. ``None`` (default) draws a fresh random subsample.

    Returns
    -------
    tuple of float
        ``(prepare_input_for_model, train, compute_spot_representations)`` seconds,
        as returned by ``run_time``.
    """
    if adata.n_obs > n_points:
        # Sample positions rather than obs names, so duplicated obs names cannot pull in
        # extra rows. Sorting keeps the spots in their original order. Copy only the
        # subsample instead of the full dataset.
        rng = np.random.default_rng(random_state)
        positions = np.sort(rng.choice(adata.n_obs, size=n_points, replace=False))
        adata_subsampled = adata[positions].copy()
    else:
        # Work on a copy so the timed run cannot mutate the caller's object.
        adata_subsampled = adata.copy()
    print(f"Running with {len(adata_subsampled)} spots...")
    return run_time(adata_subsampled)

In [None]:
# Candidate subsample sizes (number of spots).
N_POINTS_GRID = [
    1_000,
    5_000,
    10_000,
    50_000,
    100_000,
    200_000,
    400_000,
    500_000,
    1_000_000,
    10_000_000,
]

# Sizes above the dataset size would all run on the full dataset yet be recorded (and
# plotted) at their nominal size. Cap them at adata.n_obs and drop duplicates, so
# n_points is the number of spots actually used. Largest first.
effective_sizes = sorted({min(n, adata.n_obs) for n in N_POINTS_GRID}, reverse=True)

timing_records = []
for n_points in effective_sizes:
    t_prepare, t_train, t_compute = measure_time(adata, n_points)
    timing_records.append({
        "n_points": n_points,
        "prepare_input_for_model": t_prepare,
        "train": t_train,
        "compute_spot_representations": t_compute,
    })

time_df = pd.DataFrame(timing_records, columns=["n_points",
                                                "prepare_input_for_model",
                                                "train",
                                                "compute_spot_representations"])
time_df

In [None]:
# Save the benchmark table to CSV, omitting the DataFrame's row index.
time_df.to_csv(path_or_buf="run_time.csv", index=False)

In [None]:
# Long format: one row per (n_points, pipeline stage) with its run time in seconds.
to_plot = time_df.melt(id_vars=["n_points"], var_name="function", value_name="seconds")
(
    p9.ggplot(to_plot, p9.aes("n_points", "seconds", color="function", group="function"))
    + p9.geom_point()
    + p9.geom_line(linetype="dashed")
    + p9.scale_x_log10()
    + p9.labs(
        x="Number of spots (log scale)",
        y="Run time (seconds)",
        color="AESTETIK step",
        title="AESTETIK run time vs. number of spots",
    )
    + p9.theme_bw()
)