In [1]:
!pip install huggingface_hub
!pip install python-dotenv

from huggingface_hub import hf_hub_download, HfApi, ModelFilter, snapshot_download, login
from dotenv import load_dotenv
import os

REPO_ID = 'scvi-tools/MODEL-FOR-UNIT-TESTING-1'
API_TOKEN = os.getenv("API_TOKEN")

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m23.2.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m23.2.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


# Repository Setup

To use this repository, you need to set up a few things locally. Please follow the steps below:

### 1. Create `.env` File

1. Create a file named `.env` in the root directory of this repository.

2. Open the `.env` file in a text editor.

3. Add your Hugging Face API key to the `.env` file using the following format:

   ```plaintext
   API_TOKEN="your_api_key_here"
   ```

### 2. Create a `/local_models` directory

1. Create a /local_models directory in the root of the scvi-hub directory. 
This is where the local models from scvi-hub will be stored.


In [2]:
# Log in to the Hugging Face Hub with the token read from the API_TOKEN environment variable.
login(token=API_TOKEN)

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /root/.cache/huggingface/token
Login successful


### List the models under a repository

In [3]:
api = HfApi()
print(api.get_model_tags())

# All models on the Hub whose author is "scvi-tools".
models = api.list_models(filter=ModelFilter(author="scvi-tools"))

# A model is skipped when its ID contains any of these words (case-insensitive).
excluded_keywords = ("demo", "testing", "stereoscope", "condscvi")

filtered_model_ids = [
    info.modelId
    for info in models
    if not any(word in info.modelId.lower() for word in excluded_keywords)
]

# Show every model ID that passed the filter
print(filtered_model_ids)
# Show the single model ID used by the following cells
print(filtered_model_ids[1])


Available Attributes or Keys:
 * dataset
 * language
 * library
 * license
 * pipeline_tag

['scvi-tools/human-lung-cell-atlas', 'scvi-tools/tabula-sapiens-bladder-scvi', 'scvi-tools/tabula-sapiens-bladder-scanvi', 'scvi-tools/tabula-sapiens-blood-scvi', 'scvi-tools/tabula-sapiens-blood-scanvi', 'scvi-tools/tabula-sapiens-bone_marrow-scvi', 'scvi-tools/tabula-sapiens-bone_marrow-scanvi', 'scvi-tools/tabula-sapiens-eye-scvi', 'scvi-tools/tabula-sapiens-eye-scanvi', 'scvi-tools/tabula-sapiens-fat-scvi', 'scvi-tools/tabula-sapiens-fat-scanvi', 'scvi-tools/tabula-sapiens-heart-scvi', 'scvi-tools/tabula-sapiens-heart-scanvi', 'scvi-tools/tabula-sapiens-large_intestine-scvi', 'scvi-tools/tabula-sapiens-large_intestine-scanvi', 'scvi-tools/tabula-sapiens-liver-scvi', 'scvi-tools/tabula-sapiens-liver-scanvi', 'scvi-tools/tabula-sapiens-lung-scvi', 'scvi-tools/tabula-sapiens-lung-scanvi', 'scvi-tools/tabula-sapiens-lymph_node-scvi', 'scvi-tools/tabula-sapiens-lymph_node-scanvi', 'scvi-tools/tab

### Download the models and cache them

The models will be downloaded into the local models directory (`model_path` in the cell below). The PyTorch files can be found under the `snapshots` folder.

In [4]:
# Download files to a local folder and use the cached version
# This example downloads the repository at filtered_model_ids[1]
# The model will be downloaded under a "snapshots" sub-folder of model_path
model_path = "/app/scvi_hub_models/"

folder_path = snapshot_download(
    repo_id=filtered_model_ids[1],
    allow_patterns=["*.h5ad", "*.pt", "*.json", "*.md"],
    cache_dir=model_path,
)


import json
import gdown

json_path = folder_path + "/_scvi_required_metadata.json"
# Context manager so the file handle is closed as soon as the JSON is parsed.
with open(json_path) as f:
    data = json.load(f)

# Pull the fields needed later out of the metadata dictionary.
training_data_url = data.pop("training_data_url")
model_parent_module = data.pop("model_parent_module")
model_cls_name = data.pop("model_cls_name")
output = folder_path + '/reference_atlas.h5ad'

# Fetch the training data next to the model files.
gdown.download(training_data_url, output)

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Downloading...
From: https://zenodo.org/api/files/fd2c61e6-f4cd-4984-ade0-24d26d9adef6/TS_Bladder_filtered.h5ad
To: /app/scvi_hub_models/models--scvi-tools--tabula-sapiens-bladder-scvi/snapshots/89565a92a1504e9b22d73071c6e014e24b7d32c0/reference_atlas.h5ad
100%|██████████| 1.13G/1.13G [05:48<00:00, 3.24MB/s]


'/app/scvi_hub_models/models--scvi-tools--tabula-sapiens-bladder-scvi/snapshots/89565a92a1504e9b22d73071c6e014e24b7d32c0/reference_atlas.h5ad'

### Start the ML-pipeline

Inject the model and the reference atlas into the existing ML pipeline.

In [8]:
import scvi
import scarches as sca
import scanpy

from huggingface_hub import ModelCard
import re

import ast

# Load the model card of the selected model (filtered_model_ids[1]) from the Hugging Face Hub
card = ModelCard.load(filtered_model_ids[1])

def parse_json_from_text(text):
    """
    Extract the "model_setup_anndata_args" JSON block from a text and decode it.

    The text is searched for the marker ``**model_setup_anndata_args**:``
    followed by a fenced json block; the content of that fence is decoded.

    Args:
        text (str): Text expected to contain the fenced JSON object.

    Returns:
        dict or None: The decoded JSON value, or None (after printing a
                      message) when the block is missing or is not valid JSON.
    """
    # Marker, optional whitespace, then the lazily matched body of the json fence.
    fence_regex = r"\*\*model_setup_anndata_args\*\*:\s*```json\s*(.*?)```"

    # DOTALL lets the captured body span several lines.
    found = re.search(fence_regex, text, re.DOTALL)

    # Guard clause: nothing to decode when the marker/fence is absent.
    if not found:
        print("JSON object not found in the text.")
        return None

    try:
        return json.loads(found.group(1))
    except json.JSONDecodeError as err:
        print("Error decoding JSON:", err)
        return None

# Parse the model_setup_anndata_args object out of the model card text.
model_args = parse_json_from_text(card.text)
if model_args is None:
    # parse_json_from_text already printed the reason; stop here instead of
    # failing on the .pop() calls below with an unrelated AttributeError.
    raise ValueError("model_setup_anndata_args could not be read from the model card")

batch_key = model_args.pop("batch_key")
labels_key = model_args.pop("labels_key")

# The query set is a 10% subsample of the reference atlas.
reference = scanpy.read_h5ad(folder_path + "/reference_atlas.h5ad")
query = scanpy.pp.subsample(reference, 0.1, copy=True)

eval_string = model_parent_module + "." + model_cls_name

# Allow-list of supported model classes. The module and class names come from
# downloaded metadata, so they are looked up here rather than handed to eval().
supported_model_classes = {
    "scvi.model.SCVI": scvi.model.SCVI,
    "scvi.model.SCANVI": scvi.model.SCANVI,
}

if eval_string not in supported_model_classes:
    raise ValueError(
        "Wrong module/class combination fetched from huggingface: " + eval_string
    )
print(eval_string)
model_cls = supported_model_classes[eval_string]

# Prepare the query data for the downloaded reference model.
model_cls.prepare_query_anndata(query, folder_path)

model = model_cls.load_query_data(
    query,
    folder_path,
    freeze_dropout=True,
)

model.train(
    max_epochs=10,
    plan_kwargs=dict(weight_decay=0.0),
    check_val_every_n_epoch=10,
    use_gpu=False
)

# Combine query and reference; the "bkey" column records which of the two each cell came from.
combined_adata = query.concatenate(reference, batch_key="bkey")
model_cls.setup_anndata(combined_adata)

# Store the latent representation of the combined data on the AnnData object.
combined_adata.obsm["latent_rep"] = model.get_latent_representation(combined_adata)



scvi.model.SCVI
[34mINFO    [0m File                                                                                                      
         [35m/app/scvi_hub_models/models--scvi-tools--tabula-sapiens-bladder-scvi/snapshots/89565a92a1504e9b22d73071c6e[0m
         [35m014e24b7d32c0/[0m[95mmodel.pt[0m already downloaded                                                                 
[34mINFO    [0m Found [1;36m100.0[0m% reference vars in query data.                                                                
[34mINFO    [0m File                                                                                                      
         [35m/app/scvi_hub_models/models--scvi-tools--tabula-sapiens-bladder-scvi/snapshots/89565a92a1504e9b22d73071c6e[0m
         [35m014e24b7d32c0/[0m[95mmodel.pt[0m already downloaded                                                                 


  new_mapping = _make_column_categorical(
  accelerator, lightning_devices, device = parse_device_args(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Epoch 4/10:  30%|███       | 3/10 [00:06<00:14,  2.06s/it, v_num=1, train_loss_step=4.04e+3, train_loss_epoch=3.52e+3]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")

See the tutorial for concat at: https://anndata.readthedocs.io/en/latest/concatenation.html


[34mINFO    [0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             
