# loading data

## check for data file, download if not exist

In [3]:
import os
import requests
import anndata as ad
import logging
from tqdm import tqdm

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def process_downloaded_data(adata):
    adata.obs['cell_type'] = adata.obs.pop('final_annotation')
    sc.pp.highly_variable_genes(adata, n_top_genes=4000, inplace=True)
    adata = adata[:, adata.var['highly_variable']]
    return adata
    
def load_or_download_anndata(folder_path: str, download_url: str) -> ad.AnnData:
    """
    Check if an AnnData file exists in the specified folder with the proposed name. 
    If not, download it from the given URL. Then load and return the AnnData object.

    Parameters:
    - folder_path (str): Path to the folder where the AnnData file should be located.
    - download_url (str): URL to download the file if it does not exist.

    Returns:
    - anndata.AnnData: Loaded AnnData object.
    """
    # Ensure the folder exists
    os.makedirs(folder_path, exist_ok=True)

    # Extract file name from the URL
    file_name = download_url.split('/')[-1]
    file_path = os.path.join(folder_path, file_name)

    # Check if the file exists, download if missing
    if not os.path.exists(file_path):
        logging.info(f"AnnData file not found in folder: {folder_path}. Downloading...")
        response = requests.get(download_url, stream=True)
        response.raise_for_status()  # Raise an error if the request failed
        
        # Get total file size for the progress bar
        total_size = int(response.headers.get('content-length', 0))
        with open(file_path, "wb") as f, tqdm(
            desc="Downloading",
            total=total_size,
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
                bar.update(len(chunk))
        logging.info(f"File downloaded and saved as {file_path}.")
    else:
        logging.info(f"AnnData file found at {file_path}.")

    # Load and return the AnnData file
    logging.info("Loading AnnData file...")
    adata = ad.read_h5ad(file_path)
    adata = process_downloaded_data(adata)
    logging.info("AnnData file loaded successfully.")
    return adata


In [5]:
# Define file path and download URL
folder_path = "./tmp_data"
file_name = "immune.h5ad"
file_path = os.path.join(folder_path, file_name)
download_url = "https://figshare.com/ndownloader/files/25717328"

# Load or download the AnnData object
adata = load_or_download_anndata(file_path, download_url)
print("AnnData object loaded successfully!")


2024-11-15 11:36:40,308 - INFO - AnnData file not found in folder: ./tmp_data/immune.h5ad. Downloading...
Downloading: 100%|██████████| 1.92G/1.92G [00:33<00:00, 61.9MB/s]
2024-11-15 11:37:17,127 - INFO - File downloaded and saved as ./tmp_data/immune.h5ad/25717328.
2024-11-15 11:37:17,127 - INFO - Loading AnnData file...
2024-11-15 11:37:20,334 - INFO - AnnData file loaded successfully.


AnnData object loaded successfully!


## check for label encoder, generate if not exist

In [12]:
import pickle as pkl
from interpretable_ssl.utils import *

label_encoder_path = './tmp_data/le.pkl'

if os.path.exists(label_encoder_path):
    le = pkl.load(open(label_encoder_path, 'rb'))
else:
    print('fitting label encoder')
    fit_label_encoder(adata, label_encoder_path)

fitting label encoder


## init dataset object

In [13]:
from interpretable_ssl.datasets.immune import *


ds = ImmuneDataset(adata, label_encoder_path)

# train

In [22]:
# change configs, constants/MODEL_DIR to the directory you want to save model and results
# then import trainer


from interpretable_ssl.trainers.swav import *
trainer = SwAV(debug=True, dataset=ds, augmentation_type='community')

INFO - 11/15/24 12:03:14 - 0:02:55 - Starting '__init__' of class 'get train test'
INFO - 11/15/24 12:03:14 - 0:02:55 - Finished '__init__' of class 'get train test' in 0.0526 seconds


In [23]:
trainer.pretrain_epochs = 5

In [24]:
trainer.setup()

INFO - 11/15/24 12:03:23 - 0:00:00 - all_latent: None
                                     augmentation_type: community
                                     base_lr: 4.8
                                     batch_size: 512
                                     cell_type_key: cell_type
                                     checkpoint_freq: 25
                                     condition_key: study
                                     crops_for_assign: [0, 1]
                                     cvae_loss_scaler: 0.0
                                     cvae_reg: 0
                                     dataset: pbmc-immune
                                     dataset_id: pbmc-immune
                                     debug: True
                                     default_values: {'dataset_id': 'pbmc-immune', 'model_name_version': 4, 'num_prototypes': 300, 'hidden_dim': 64, 'latent_dims': 8, 'batch_size': 512, 'fine_tuning_epochs': 0, 'experiment_name': '', 'condition_key': 'study', 'c

Embedding dictionary:
 	Num conditions: [3]
 	Embedding dim: [10]
Encoder Architecture:
	Input Layer in, out and cond: 4000 64 10
	Mean/Var Layer in/out: 64 8
Decoder Architecture:
	First Layer in, out and cond:  8 64 10
	Output Layer in/out:  64 4000 



In [None]:
trainer.train()

VBox(children=(Label(value='0.014 MB of 0.030 MB uploaded\r'), FloatProgress(value=0.47019951709238783, max=1.…

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112246341589424, max=1.0…

100%|██████████| 57/57 [00:03<00:00, 16.88it/s]
INFO - 11/15/24 12:04:13 - 0:00:50 - Starting to build community graph.
INFO - 11/15/24 12:04:13 - 0:00:50 - Starting to build community graph.
INFO - 11/15/24 12:04:13 - 0:00:50 - Starting to build community graph.
INFO - 11/15/24 12:04:13 - 0:00:50 - Starting to build community graph.
INFO - 11/15/24 12:04:13 - 0:00:50 - Running Scanpy neighbors with k=11.
INFO - 11/15/24 12:04:13 - 0:00:50 - Running Scanpy neighbors with k=11.
INFO - 11/15/24 12:04:13 - 0:00:50 - Starting to build community graph.


         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/15/24 12:04:13 - 0:00:50 - Running Scanpy neighbors with k=11.
INFO - 11/15/24 12:04:13 - 0:00:50 - Starting to build community graph.
INFO - 11/15/24 12:04:13 - 0:00:50 - Running Scanpy neighbors with k=11.
INFO - 11/15/24 12:04:13 - 0:00:50 - Starting to build community graph.
INFO - 11/15/24 12:04:13 - 0:00:50 - Starting to build community graph.
INFO - 11/15/24 12:04:13 - 0:00:50 - Running Scanpy neighbors with k=11.
INFO - 11/15/24 12:04:13 - 0:00:50 - Starting to build community graph.
INFO - 11/15/24 12:04:13 - 0:00:50 - Running Scanpy neighbors with k=11.
INFO - 11/15/24 12:04:13 - 0:00:50 - Running Scanpy neighbors with k=11.


         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/15/24 12:04:13 - 0:00:50 - Running Scanpy neighbors with k=11.


         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/15/24 12:04:13 - 0:00:50 - Running Scanpy neighbors with k=11.


         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/15/24 12:04:13 - 0:00:50 - Starting to build community graph.


         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.
         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/15/24 12:04:13 - 0:00:50 - Running Scanpy neighbors with k=11.


         Falling back to preprocessing with `sc.pp.pca` and default params.


INFO - 11/15/24 12:04:33 - 0:01:10 - Performing Leiden community detection.
INFO - 11/15/24 12:04:33 - 0:01:10 - Performing Leiden community detection.
INFO - 11/15/24 12:04:33 - 0:01:10 - Performing Leiden community detection.
INFO - 11/15/24 12:04:33 - 0:01:10 - Performing Leiden community detection.
INFO - 11/15/24 12:04:33 - 0:01:10 - Performing Leiden community detection.
INFO - 11/15/24 12:04:33 - 0:01:10 - Performing Leiden community detection.
INFO - 11/15/24 12:04:34 - 0:01:11 - Performing Leiden community detection.
INFO - 11/15/24 12:04:34 - 0:01:11 - Performing Leiden community detection.
INFO - 11/15/24 12:04:34 - 0:01:11 - Performing Leiden community detection.
INFO - 11/15/24 12:04:35 - 0:01:12 - Performing Leiden community detection.


# evalaute

In [None]:
query_latent = trainer.encode_query()
metric_df = MetricCalculator(
        trainer.query.adata,
        [query_latent],
        save_path=trainer.get_metric_file_path("query"),
    ).calculate()

metric_df