# loading data

## check for data file, download if not exist

In [5]:
import os
import requests
import anndata as ad
import logging
from tqdm import tqdm
import scanpy as sc
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def process_downloaded_data(adata):
    adata.obs['cell_type'] = adata.obs.pop('final_annotation')
    sc.pp.highly_variable_genes(adata, n_top_genes=4000, inplace=True)
    adata = adata[:, adata.var['highly_variable']]
    return adata
    
def load_or_download_anndata(folder_path: str, download_url: str) -> ad.AnnData:
    """
    Check if an AnnData file exists in the specified folder with the proposed name. 
    If not, download it from the given URL. Then load and return the AnnData object.

    Parameters:
    - folder_path (str): Path to the folder where the AnnData file should be located.
    - download_url (str): URL to download the file if it does not exist.

    Returns:
    - anndata.AnnData: Loaded AnnData object.
    """
    # Ensure the folder exists
    os.makedirs(folder_path, exist_ok=True)

    # Extract file name from the URL
    file_name = download_url.split('/')[-1]
    file_path = os.path.join(folder_path, file_name)

    # Check if the file exists, download if missing
    if not os.path.exists(file_path):
        logging.info(f"AnnData file not found in folder: {folder_path}. Downloading...")
        response = requests.get(download_url, stream=True)
        response.raise_for_status()  # Raise an error if the request failed
        
        # Get total file size for the progress bar
        total_size = int(response.headers.get('content-length', 0))
        with open(file_path, "wb") as f, tqdm(
            desc="Downloading",
            total=total_size,
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
                bar.update(len(chunk))
        logging.info(f"File downloaded and saved as {file_path}.")
    else:
        logging.info(f"AnnData file found at {file_path}.")

    # Load and return the AnnData file
    logging.info("Loading AnnData file...")
    adata = ad.read_h5ad(file_path)
    adata = process_downloaded_data(adata)
    logging.info("AnnData file loaded successfully.")
    return adata


In [6]:
# Define file path and download URL
folder_path = "./tmp_data"
file_name = "immune.h5ad"
file_path = os.path.join(folder_path, file_name)
download_url = "https://figshare.com/ndownloader/files/25717328"

# Load or download the AnnData object
adata = load_or_download_anndata(file_path, download_url)
print("AnnData object loaded successfully!")


2024-11-24 16:34:29,129 - INFO - AnnData file found at ./tmp_data/immune.h5ad/25717328.
2024-11-24 16:34:29,130 - INFO - Loading AnnData file...
  disp_grouped = df.groupby("mean_bin")["dispersions"]
2024-11-24 16:34:38,830 - INFO - AnnData file loaded successfully.


AnnData object loaded successfully!


## check for label encoder, generate if not exist

In [7]:
import pickle as pkl
from interpretable_ssl.utils import *

label_encoder_path = './tmp_data/le.pkl'

if os.path.exists(label_encoder_path):
    le = pkl.load(open(label_encoder_path, 'rb'))
else:
    print('fitting label encoder')
    fit_label_encoder(adata, label_encoder_path)

## init dataset object

In [8]:
from interpretable_ssl.datasets.immune import *


ds = ImmuneDataset(adata, label_encoder_path)

# train

In [9]:
# change configs, constants/MODEL_DIR to the directory you want to save model and results
# then import trainer


from interpretable_ssl.trainers.swav import *
trainer = SwAV(debug=True, dataset=ds, augmentation_type='community')

2024-11-24 16:35:31,661 - INFO - Global seed set to 0
  new_rank_zero_deprecation(
  return new_rank_zero_deprecation(*args, **kwargs)
 captum (see https://github.com/pytorch/captum).
2024-11-24 16:36:11,377 - INFO - Loading faiss with AVX512 support.
2024-11-24 16:36:11,378 - INFO - Could not load library with AVX512 support due to:
ModuleNotFoundError("No module named 'faiss.swigfaiss_avx512'")
2024-11-24 16:36:11,378 - INFO - Loading faiss with AVX2 support.
2024-11-24 16:36:11,379 - INFO - Could not load library with AVX2 support due to:
ModuleNotFoundError("No module named 'faiss.swigfaiss_avx2'")
2024-11-24 16:36:11,380 - INFO - Loading faiss.
2024-11-24 16:36:12,838 - INFO - Successfully loaded faiss.
2024-11-24 16:36:21,263 - INFO - Starting '__init__' of class 'get train test'
2024-11-24 16:36:21,304 - INFO - Finished '__init__' of class 'get train test' in 0.0404 seconds


In [10]:
trainer.pretraining_epochs = 5

In [11]:
trainer.setup()

INFO - 11/24/24 16:38:27 - 0:00:00 - all_latent: None
                                     augmentation_type: community
                                     base_lr: 4.8
                                     batch_size: 512
                                     cell_type_key: cell_type
                                     checkpoint_freq: 25
                                     condition_key: study
                                     crops_for_assign: [0, 1]
                                     cvae_epochs: 0
                                     cvae_loss_scaler: 0.0
                                     cvae_reg: 0
                                     dataset: pbmc-immune
                                     dataset_id: pbmc-immune
                                     debug: True
                                     decodable_prototypes: 0
                                     default_values: {'dataset_id': 'pbmc-immune', 'model_name_version': 6, 'num_prototypes': 300, 'hidden_dim': 64, 

Embedding dictionary:
 	Num conditions: [3]
 	Embedding dim: [10]
Encoder Architecture:
	Input Layer in, out and cond: 4000 64 10
	Mean/Var Layer in/out: 64 8
Decoder Architecture:
	First Layer in, out and cond:  8 64 10
	Output Layer in/out:  64 4000 



INFO - 11/24/24 16:38:31 - 0:00:04 - SwAVModel(
                                       (scpoli_encoder): scpoli(
                                         (embeddings): ModuleList(
                                           (0): Embedding(3, 10, max_norm=1.0)
                                         )
                                         (encoder): Encoder(
                                           (FC): Sequential(
                                             (L0): CondLayers(
                                               (expr_L): Linear(in_features=4000, out_features=64, bias=True)
                                               (cond_L): Linear(in_features=10, out_features=64, bias=False)
                                             )
                                             (N0): LayerNorm((64,), eps=1e-05, elementwise_affine=False)
                                             (A0): ReLU()
                                             (D0): Dropout(p=0.05, inplace=False)
  

In [None]:
trainer.train()

ERROR - 11/24/24 16:38:48 - 0:00:21 - Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mfatemehs-hashemig[0m. Use [1m`wandb login --relogin`[0m to force relogin


100%|██████████| 57/57 [00:17<00:00,  3.27it/s]
INFO - 11/24/24 16:40:10 - 0:01:43 - Starting to build community graph.
INFO - 11/24/24 16:40:10 - 0:01:43 - Starting to build community graph.
INFO - 11/24/24 16:40:10 - 0:01:43 - Starting to build community graph.
INFO - 11/24/24 16:40:10 - 0:01:43 - Running Scanpy neighbors with k=11.
INFO - 11/24/24 16:40:10 - 0:01:43 - Running Scanpy neighbors with k=11.
INFO - 11/24/24 16:40:10 - 0:01:43 - Starting to build community graph.
INFO - 11/24/24 16:40:10 - 0:01:43 - Running Scanpy neighbors with k=11.


         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/24/24 16:40:10 - 0:01:43 - Starting to build community graph.
INFO - 11/24/24 16:40:10 - 0:01:43 - Running Scanpy neighbors with k=11.


         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/24/24 16:40:10 - 0:01:43 - Starting to build community graph.
INFO - 11/24/24 16:40:10 - 0:01:43 - Starting to build community graph.


         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/24/24 16:40:10 - 0:01:43 - Starting to build community graph.
INFO - 11/24/24 16:40:10 - 0:01:43 - Running Scanpy neighbors with k=11.
INFO - 11/24/24 16:40:10 - 0:01:43 - Starting to build community graph.
INFO - 11/24/24 16:40:10 - 0:01:43 - Running Scanpy neighbors with k=11.
INFO - 11/24/24 16:40:10 - 0:01:43 - Starting to build community graph.


         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/24/24 16:40:10 - 0:01:43 - Running Scanpy neighbors with k=11.
INFO - 11/24/24 16:40:10 - 0:01:43 - Running Scanpy neighbors with k=11.
INFO - 11/24/24 16:40:10 - 0:01:43 - Running Scanpy neighbors with k=11.
INFO - 11/24/24 16:40:10 - 0:01:43 - Running Scanpy neighbors with k=11.


         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/24/24 16:40:39 - 0:02:12 - Performing Leiden community detection.
INFO - 11/24/24 16:40:39 - 0:02:12 - Performing Leiden community detection.
INFO - 11/24/24 16:40:39 - 0:02:12 - Performing Leiden community detection.
INFO - 11/24/24 16:40:39 - 0:02:12 - Performing Leiden community detection.
INFO - 11/24/24 16:40:39 - 0:02:12 - Performing Leiden community detection.
INFO - 11/24/24 16:40:39 - 0:02:12 - Performing Leiden community detection.
INFO - 11/24/24 16:40:39 - 0:02:12 - Performing Leiden community detection.
INFO - 11/24/24 16:40:40 - 0:02:13 - Performing Leiden community detection.
INFO - 11/24/24 16:40:40 - 0:02:13 - Performing Leiden community detection.
INFO - 11/24/24 16:40:40 - 0:02:13 - Performing Leiden community detection.
INFO - 11/24/24 16:40:46 - 0:02:19 - Community labels assigned to the cells.
INFO - 11/24/24 16:40:47 - 0:02:20 - Community labels assigned to the cells.
INFO - 11/24/24 16:40:47 - 0:02:20 - Community labels assigned to the cells.
INFO - 11

         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/24/24 16:41:13 - 0:02:46 - Running Scanpy neighbors with k=11.
INFO - 11/24/24 16:41:13 - 0:02:46 - Starting to build community graph.
INFO - 11/24/24 16:41:13 - 0:02:46 - Starting to build community graph.
INFO - 11/24/24 16:41:13 - 0:02:46 - Running Scanpy neighbors with k=11.


         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/24/24 16:41:13 - 0:02:46 - Starting to build community graph.


         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/24/24 16:41:13 - 0:02:46 - Running Scanpy neighbors with k=11.


         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/24/24 16:41:13 - 0:02:46 - Starting to build community graph.


         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/24/24 16:41:13 - 0:02:46 - Running Scanpy neighbors with k=11.
INFO - 11/24/24 16:41:13 - 0:02:46 - Running Scanpy neighbors with k=11.
INFO - 11/24/24 16:41:13 - 0:02:46 - Running Scanpy neighbors with k=11.


         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/24/24 16:41:13 - 0:02:46 - Starting to build community graph.
INFO - 11/24/24 16:41:13 - 0:02:46 - Running Scanpy neighbors with k=11.


         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/24/24 16:41:41 - 0:03:14 - Performing Leiden community detection.
INFO - 11/24/24 16:41:41 - 0:03:14 - Performing Leiden community detection.
INFO - 11/24/24 16:41:42 - 0:03:15 - Performing Leiden community detection.
INFO - 11/24/24 16:41:42 - 0:03:15 - Performing Leiden community detection.
INFO - 11/24/24 16:41:42 - 0:03:15 - Performing Leiden community detection.
INFO - 11/24/24 16:41:42 - 0:03:15 - Performing Leiden community detection.
INFO - 11/24/24 16:41:42 - 0:03:15 - Performing Leiden community detection.
INFO - 11/24/24 16:41:42 - 0:03:15 - Performing Leiden community detection.
INFO - 11/24/24 16:41:42 - 0:03:15 - Performing Leiden community detection.
INFO - 11/24/24 16:41:42 - 0:03:15 - Performing Leiden community detection.
INFO - 11/24/24 16:41:49 - 0:03:22 - Community labels assigned to the cells.
INFO - 11/24/24 16:41:49 - 0:03:22 - Community labels assigned to the cells.
INFO - 11/24/24 16:41:49 - 0:03:22 - Community labels assigned to the cells.
INFO - 11

         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/24/24 16:42:08 - 0:03:41 - Starting to build community graph.
INFO - 11/24/24 16:42:08 - 0:03:41 - Starting to build community graph.
INFO - 11/24/24 16:42:08 - 0:03:41 - Starting to build community graph.
INFO - 11/24/24 16:42:08 - 0:03:41 - Starting to build community graph.
INFO - 11/24/24 16:42:08 - 0:03:41 - Running Scanpy neighbors with k=11.
INFO - 11/24/24 16:42:08 - 0:03:41 - Starting to build community graph.
INFO - 11/24/24 16:42:08 - 0:03:41 - Running Scanpy neighbors with k=11.
INFO - 11/24/24 16:42:08 - 0:03:41 - Running Scanpy neighbors with k=11.
INFO - 11/24/24 16:42:08 - 0:03:41 - Starting to build community graph.
INFO - 11/24/24 16:42:08 - 0:03:41 - Running Scanpy neighbors with k=11.
INFO - 11/24/24 16:42:08 - 0:03:41 - Running Scanpy neighbors with k=11.
INFO - 11/24/24 16:42:08 - 0:03:41 - Running Scanpy neighbors with k=11.


         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/24/24 16:42:08 - 0:03:41 - Starting to build community graph.


         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/24/24 16:42:08 - 0:03:41 - Running Scanpy neighbors with k=11.
INFO - 11/24/24 16:42:08 - 0:03:41 - Running Scanpy neighbors with k=11.


         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/24/24 16:42:32 - 0:04:05 - Performing Leiden community detection.
INFO - 11/24/24 16:42:33 - 0:04:06 - Performing Leiden community detection.
INFO - 11/24/24 16:42:34 - 0:04:07 - Performing Leiden community detection.
INFO - 11/24/24 16:42:35 - 0:04:07 - Performing Leiden community detection.
INFO - 11/24/24 16:42:38 - 0:04:11 - Performing Leiden community detection.
INFO - 11/24/24 16:42:38 - 0:04:11 - Performing Leiden community detection.
INFO - 11/24/24 16:42:38 - 0:04:11 - Performing Leiden community detection.
INFO - 11/24/24 16:42:39 - 0:04:12 - Performing Leiden community detection.
INFO - 11/24/24 16:42:39 - 0:04:12 - Performing Leiden community detection.
INFO - 11/24/24 16:42:39 - 0:04:12 - Performing Leiden community detection.
INFO - 11/24/24 16:42:42 - 0:04:15 - Community labels assigned to the cells.
INFO - 11/24/24 16:42:42 - 0:04:15 - Community labels assigned to the cells.
INFO - 11/24/24 16:42:45 - 0:04:18 - Community labels assigned to the cells.
INFO - 11

         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/24/24 16:43:03 - 0:04:36 - Running Scanpy neighbors with k=11.
INFO - 11/24/24 16:43:03 - 0:04:36 - Starting to build community graph.


         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/24/24 16:43:03 - 0:04:36 - Running Scanpy neighbors with k=11.
INFO - 11/24/24 16:43:03 - 0:04:36 - Starting to build community graph.
INFO - 11/24/24 16:43:03 - 0:04:36 - Starting to build community graph.
INFO - 11/24/24 16:43:03 - 0:04:36 - Starting to build community graph.
INFO - 11/24/24 16:43:03 - 0:04:36 - Starting to build community graph.
INFO - 11/24/24 16:43:03 - 0:04:36 - Running Scanpy neighbors with k=11.


         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/24/24 16:43:03 - 0:04:36 - Running Scanpy neighbors with k=11.


         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/24/24 16:43:03 - 0:04:36 - Starting to build community graph.


         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/24/24 16:43:03 - 0:04:36 - Starting to build community graph.
INFO - 11/24/24 16:43:03 - 0:04:36 - Running Scanpy neighbors with k=11.
INFO - 11/24/24 16:43:03 - 0:04:36 - Running Scanpy neighbors with k=11.
INFO - 11/24/24 16:43:03 - 0:04:36 - Running Scanpy neighbors with k=11.
INFO - 11/24/24 16:43:03 - 0:04:36 - Running Scanpy neighbors with k=11.


         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/24/24 16:43:25 - 0:04:58 - Performing Leiden community detection.
INFO - 11/24/24 16:43:26 - 0:04:59 - Performing Leiden community detection.
INFO - 11/24/24 16:43:26 - 0:04:59 - Performing Leiden community detection.
INFO - 11/24/24 16:43:30 - 0:05:03 - Performing Leiden community detection.
INFO - 11/24/24 16:43:33 - 0:05:06 - Performing Leiden community detection.
INFO - 11/24/24 16:43:34 - 0:05:06 - Performing Leiden community detection.
INFO - 11/24/24 16:43:34 - 0:05:07 - Performing Leiden community detection.
INFO - 11/24/24 16:43:34 - 0:05:07 - Performing Leiden community detection.
INFO - 11/24/24 16:43:34 - 0:05:07 - Performing Leiden community detection.
INFO - 11/24/24 16:43:35 - 0:05:08 - Performing Leiden community detection.
INFO - 11/24/24 16:43:37 - 0:05:10 - Community labels assigned to the cells.
INFO - 11/24/24 16:43:38 - 0:05:11 - Community labels assigned to the cells.
INFO - 11/24/24 16:43:38 - 0:05:11 - Community labels assigned to the cells.
INFO - 11

         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/24/24 16:43:58 - 0:05:31 - Starting to build community graph.


         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/24/24 16:43:58 - 0:05:31 - Starting to build community graph.
INFO - 11/24/24 16:43:58 - 0:05:31 - Starting to build community graph.


         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/24/24 16:43:58 - 0:05:31 - Starting to build community graph.
INFO - 11/24/24 16:43:58 - 0:05:31 - Running Scanpy neighbors with k=11.
INFO - 11/24/24 16:43:58 - 0:05:31 - Running Scanpy neighbors with k=11.
INFO - 11/24/24 16:43:58 - 0:05:31 - Starting to build community graph.


         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/24/24 16:43:58 - 0:05:31 - Running Scanpy neighbors with k=11.
INFO - 11/24/24 16:43:58 - 0:05:31 - Running Scanpy neighbors with k=11.
INFO - 11/24/24 16:43:58 - 0:05:31 - Starting to build community graph.
INFO - 11/24/24 16:43:58 - 0:05:31 - Running Scanpy neighbors with k=11.
INFO - 11/24/24 16:43:58 - 0:05:31 - Starting to build community graph.
INFO - 11/24/24 16:43:58 - 0:05:31 - Running Scanpy neighbors with k=11.


         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/24/24 16:43:58 - 0:05:31 - Running Scanpy neighbors with k=11.


         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/24/24 16:44:26 - 0:05:59 - Performing Leiden community detection.
INFO - 11/24/24 16:44:26 - 0:05:59 - Performing Leiden community detection.
INFO - 11/24/24 16:44:26 - 0:05:59 - Performing Leiden community detection.
INFO - 11/24/24 16:44:27 - 0:06:00 - Performing Leiden community detection.
INFO - 11/24/24 16:44:27 - 0:06:00 - Performing Leiden community detection.
INFO - 11/24/24 16:44:28 - 0:06:01 - Performing Leiden community detection.
INFO - 11/24/24 16:44:28 - 0:06:01 - Performing Leiden community detection.
INFO - 11/24/24 16:44:28 - 0:06:01 - Performing Leiden community detection.
INFO - 11/24/24 16:44:29 - 0:06:02 - Performing Leiden community detection.
INFO - 11/24/24 16:44:29 - 0:06:02 - Performing Leiden community detection.
INFO - 11/24/24 16:44:36 - 0:06:09 - Community labels assigned to the cells.
INFO - 11/24/24 16:44:36 - 0:06:09 - Community labels assigned to the cells.
INFO - 11/24/24 16:44:36 - 0:06:09 - Community labels assigned to the cells.
INFO - 11

# evalaute

In [None]:
query_latent = trainer.encode_query()
metric_df = MetricCalculator(
        trainer.query.adata,
        [query_latent],
        save_path=trainer.get_metric_file_path("query"),
    ).calculate()

metric_df

# debug

In [34]:
model = trainer.get_model()

In [37]:
self = trainer
scpoli_model = self.get_scpoli_model(model)

In [38]:
scpoli_model.condition_encoders

{'study': {'Oetjen': 0, '10X': 1, 'Sun': 2}}

The history saving thread hit an unexpected error (OperationalError('attempt to write a readonly database')).History will not be written to the database.


In [39]:
scpoli_model.conditions_combined_encoder

{'Oetjen': 0, '10X': 1, 'Sun': 2}

In [41]:
query_model = scPoli.load_query_data(
            adata=self.query.adata,
            reference_model=self.get_scpoli(),
            labeled_indices=[],
        )

Embedding dictionary:
 	Num conditions: [5]
 	Embedding dim: [10]
Encoder Architecture:
	Input Layer in, out and cond: 4000 64 10
	Mean/Var Layer in/out: 64 8
Decoder Architecture:
	First Layer in, out and cond:  8 64 10
	Output Layer in/out:  64 4000 



In [44]:
model.set_scpoli_model(query_model.model)

In [46]:
model.to(self.device)
trainer.encode_adata(trainer.query.adata, model)

100%|██████████| 9/9 [00:01<00:00,  7.16it/s]


tensor([[ 0.433, -0.228, -0.155, -0.464,  0.459, -0.111,  0.545,  0.031],
        [ 0.222, -0.323,  0.310, -0.551,  0.622, -0.124,  0.203, -0.056],
        [ 0.304,  0.058, -0.051,  0.521,  0.599, -0.208, -0.036, -0.476],
        [-0.452, -0.218, -0.107,  0.415,  0.597, -0.228,  0.372, -0.137],
        [ 0.521, -0.270, -0.003,  0.341, -0.298, -0.403,  0.292, -0.450],
        [ 0.417, -0.062, -0.093,  0.275,  0.513, -0.547, -0.274,  0.318],
        [-0.381, -0.208, -0.280, -0.058,  0.104, -0.265, -0.465, -0.658],
        ...,
        [-0.740,  0.472, -0.389, -0.115,  0.042,  0.074, -0.064,  0.231],
        [ 0.133, -0.706,  0.147,  0.151,  0.096, -0.508,  0.406, -0.088],
        [-0.094, -0.796, -0.024, -0.200,  0.276, -0.394,  0.231, -0.181],
        [-0.237, -0.066, -0.472,  0.256,  0.480,  0.165,  0.176, -0.602],
        [-0.657, -0.295, -0.278, -0.010, -0.024,  0.453,  0.323,  0.305],
        [-0.257, -0.159, -0.231,  0.126,  0.277,  0.003, -0.609, -0.626],
        [ 0.818, -0.333, 

In [None]:
swav model
swav trainer
init 
test

In [92]:
import importlib

# Reloading the specific modules
import interpretable_ssl.models.swav
import interpretable_ssl.trainers.adaptive_trainer
import interpretable_ssl.trainers.scpoli_trainer
import interpretable_ssl.trainers.swav

importlib.reload(interpretable_ssl.models.swav)
importlib.reload(interpretable_ssl.trainers.scpoli_trainer)
importlib.reload(interpretable_ssl.trainers.adaptive_trainer)
importlib.reload(interpretable_ssl.trainers.swav)

<module 'interpretable_ssl.trainers.swav' from '/ictstr01/home/icb/fatemehs.hashemig/codes/interpretable-ssl/interpretable_ssl/trainers/swav.py'>

In [93]:
from interpretable_ssl.trainers.swav import *
trainer = SwAV(debug=True, dataset=ds, augmentation_type='community')

INFO - 11/17/24 07:31:36 - 0:02:20 - Starting '__init__' of class 'get train test'
INFO - 11/17/24 07:31:36 - 0:02:20 - Finished '__init__' of class 'get train test' in 0.0437 seconds


In [94]:
from interpretable_ssl.models.swav import *

In [95]:
self = trainer
model = SwAVModel(self.latent_dims, self.num_prototypes, self.ref.adata)

Embedding dictionary:
 	Num conditions: [3]
 	Embedding dim: [10]
Encoder Architecture:
	Input Layer in, out and cond: 4000 64 10
	Mean/Var Layer in/out: 64 8
Decoder Architecture:
	First Layer in, out and cond:  8 64 10
	Output Layer in/out:  64 4000 



In [96]:
self = trainer

In [98]:
self.setup()

INFO - 11/17/24 07:31:49 - 0:00:00 - all_latent: None
                                     augmentation_type: community
                                     base_lr: 4.8
                                     batch_size: 512
                                     cell_type_key: cell_type
                                     checkpoint_freq: 25
                                     condition_key: study
                                     crops_for_assign: [0, 1]
                                     cvae_loss_scaler: 0.0
                                     cvae_reg: 0
                                     dataset: pbmc-immune
                                     dataset_id: pbmc-immune
                                     debug: True
                                     default_values: {'dataset_id': 'pbmc-immune', 'model_name_version': 4, 'num_prototypes': 300, 'hidden_dim': 64, 'latent_dims': 8, 'batch_size': 512, 'fine_tuning_epochs': 0, 'experiment_name': '', 'condition_key': 'study', 'c

Embedding dictionary:
 	Num conditions: [3]
 	Embedding dim: [10]
Encoder Architecture:
	Input Layer in, out and cond: 4000 64 10
	Mean/Var Layer in/out: 64 8
Decoder Architecture:
	First Layer in, out and cond:  8 64 10
	Output Layer in/out:  64 4000 



Embedding dictionary:
 	Num conditions: [3]
 	Embedding dim: [10]
Encoder Architecture:
	Input Layer in, out and cond: 4000 64 10
	Mean/Var Layer in/out: 64 8
Decoder Architecture:
	First Layer in, out and cond:  8 64 10
	Output Layer in/out:  64 4000 



SwAVModel(
  (scpoli_encoder): scpoli(
    (embeddings): ModuleList(
      (0): Embedding(3, 10, max_norm=1.0)
    )
    (encoder): Encoder(
      (FC): Sequential(
        (L0): CondLayers(
          (expr_L): Linear(in_features=4000, out_features=64, bias=True)
          (cond_L): Linear(in_features=10, out_features=64, bias=False)
        )
        (N0): LayerNorm((64,), eps=1e-05, elementwise_affine=False)
        (A0): ReLU()
        (D0): Dropout(p=0.05, inplace=False)
      )
      (mean_encoder): Linear(in_features=64, out_features=8, bias=True)
      (log_var_encoder): Linear(in_features=64, out_features=8, bias=True)
    )
    (decoder): Decoder(
      (FirstL): Sequential(
        (L0): CondLayers(
          (expr_L): Linear(in_features=8, out_features=64, bias=False)
          (cond_L): Linear(in_features=10, out_features=64, bias=False)
        )
        (N0): LayerNorm((64,), eps=1e-05, elementwise_affine=False)
        (A0): ReLU()
        (D0): Dropout(p=0.05, inplace=F

In [111]:
self.get_model()

Embedding dictionary:
 	Num conditions: [3]
 	Embedding dim: [10]
Encoder Architecture:
	Input Layer in, out and cond: 4000 64 10
	Mean/Var Layer in/out: 64 8
Decoder Architecture:
	First Layer in, out and cond:  8 64 10
	Output Layer in/out:  64 4000 



SwAVModel(
  (scpoli_encoder): scpoli(
    (embeddings): ModuleList(
      (0): Embedding(3, 10, max_norm=1.0)
    )
    (encoder): Encoder(
      (FC): Sequential(
        (L0): CondLayers(
          (expr_L): Linear(in_features=4000, out_features=64, bias=True)
          (cond_L): Linear(in_features=10, out_features=64, bias=False)
        )
        (N0): LayerNorm((64,), eps=1e-05, elementwise_affine=False)
        (A0): ReLU()
        (D0): Dropout(p=0.05, inplace=False)
      )
      (mean_encoder): Linear(in_features=64, out_features=8, bias=True)
      (log_var_encoder): Linear(in_features=64, out_features=8, bias=True)
    )
    (decoder): Decoder(
      (FirstL): Sequential(
        (L0): CondLayers(
          (expr_L): Linear(in_features=8, out_features=64, bias=False)
          (cond_L): Linear(in_features=10, out_features=64, bias=False)
        )
        (N0): LayerNorm((64,), eps=1e-05, elementwise_affine=False)
        (A0): ReLU()
        (D0): Dropout(p=0.05, inplace=F

In [110]:
self.split_train_data()

In [112]:
self.finetune_ds.adata

View of AnnData object with n_obs × n_vars = 2914 × 4000
    obs: 'batch', 'chemistry', 'data_type', 'dpt_pseudotime', 'mt_frac', 'n_counts', 'n_genes', 'sample_ID', 'size_factors', 'species', 'study', 'tissue', 'cell_type', 'conditions_combined'
    var: 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'hvg'
    layers: 'counts'

In [113]:
q_model = trainer.adapt_ref_model(model, self.finetune_ds.adata)

Embedding dictionary:
 	Num conditions: [3]
 	Embedding dim: [10]
Encoder Architecture:
	Input Layer in, out and cond: 4000 64 10
	Mean/Var Layer in/out: 64 8
Decoder Architecture:
	First Layer in, out and cond:  8 64 10
	Output Layer in/out:  64 4000 



In [114]:
q_model.scpoli_.conditions_

{'study': ['Oetjen', '10X', 'Sun']}

In [107]:
q_model.scpoli_encoder.conditions_combined

['Oetjen', '10X', 'Sun', 'Freytag', 'Villani']

In [108]:
ref_model = self.get_model()

Embedding dictionary:
 	Num conditions: [3]
 	Embedding dim: [10]
Encoder Architecture:
	Input Layer in, out and cond: 4000 64 10
	Mean/Var Layer in/out: 64 8
Decoder Architecture:
	First Layer in, out and cond:  8 64 10
	Output Layer in/out:  64 4000 



In [109]:
ref_model.scpoli_encoder.conditions_combined

['Oetjen', '10X', 'Sun']