In [29]:
from huggingface_hub import hf_hub_download, HfApi, ModelFilter, snapshot_download, login, ModelCard
from dotenv import load_dotenv
import re, json
import os

REPO_ID = 'scvi-tools/MODEL-FOR-UNIT-TESTING-1'
API_TOKEN = os.getenv("API_TOKEN")

# Repository Setup

To use this repository, you need to set up a few things locally. Please follow the steps below:

### 1. Create `.env` File

1. Create a file named `.env` in the root directory of this repository.

2. Open the `.env` file in a text editor.

3. Add your Hugging Face API key to the `.env` file using the following format:

   ```plaintext
   API_TOKEN="your_api_key_here"

### 2. Creata a /local_models directory

1. Create a /local_models directory in the root of the scvi-hub directory. 
This is where the local models from scvi-hub will be stored.


In [4]:
login(token=API_TOKEN)

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

### List the models under a repository

In [5]:
api = HfApi()
models = api.list_models(
    filter=ModelFilter(
        author="scvi-tools",
    )
)

# Get all model IDs that do not have the words demo or testing in the name
filtered_model_ids = [model.modelId for model in models if "demo" not in model.modelId.lower() and "testing" not in model.modelId.lower()]

# Get all the model ids
print(filtered_model_ids)
# Get a certain model id
print(filtered_model_ids[2])


['scvi-tools/human-lung-cell-atlas', 'scvi-tools/tabula-sapiens-bladder-scvi', 'scvi-tools/tabula-sapiens-bladder-scanvi', 'scvi-tools/tabula-sapiens-bladder-condscvi', 'scvi-tools/tabula-sapiens-bladder-stereoscope', 'scvi-tools/tabula-sapiens-blood-scvi', 'scvi-tools/tabula-sapiens-blood-scanvi', 'scvi-tools/tabula-sapiens-blood-condscvi', 'scvi-tools/tabula-sapiens-blood-stereoscope', 'scvi-tools/tabula-sapiens-bone_marrow-scvi', 'scvi-tools/tabula-sapiens-bone_marrow-scanvi', 'scvi-tools/tabula-sapiens-bone_marrow-condscvi', 'scvi-tools/tabula-sapiens-bone_marrow-stereoscope', 'scvi-tools/tabula-sapiens-eye-scvi', 'scvi-tools/tabula-sapiens-eye-scanvi', 'scvi-tools/tabula-sapiens-eye-condscvi', 'scvi-tools/tabula-sapiens-eye-stereoscope', 'scvi-tools/tabula-sapiens-fat-scvi', 'scvi-tools/tabula-sapiens-fat-scanvi', 'scvi-tools/tabula-sapiens-fat-condscvi', 'scvi-tools/tabula-sapiens-fat-stereoscope', 'scvi-tools/tabula-sapiens-heart-scvi', 'scvi-tools/tabula-sapiens-heart-scanvi', 

### Download the models and cache them

The models will be downloaded in the local_models directory. The Pytorch files can be found under the /snapshots folder.

In [8]:
# Download files to local folder and use the cached version
# This example downloads the repository that has model_id[4]
# The model will be downloaded under "/snapshot"
snapshot_download(repo_id=filtered_model_ids[4], allow_patterns=["*.h5ad","*.pt"],cache_dir='C:/Users/Ronald/Desktop/Helmholtz/scvi-hub/local_models')

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

'C:/Users/Ronald/Desktop/Helmholtz/scvi-hub/local_models\\models--scvi-tools--tabula-sapiens-bladder-stereoscope\\snapshots\\4c503c88748da6aa9021489a13d19ef1da0b5506'

### Fetch Necessary Model Metadata

In [35]:
# Get the card of the corresponding model
card = ModelCard.load('scvi-tools/tabula-sapiens-uterus-scanvi')

def parse_json_from_text(text):
    """
    Parses the "model_setup_anndata_args" object from a given text block and returns it as a Python dictionary.
    
    Args:
        text (str): The input text containing the JSON object.
        
    Returns:
        dict or None: If the JSON object is successfully parsed, returns a Python dictionary.
                      If parsing fails or if the JSON object is not found, returns None.
    """
    # Define the pattern to find the "model_setup_anndata_args" JSON object
    pattern = r"\*\*model_setup_anndata_args\*\*:\s*```json\s*(.*?)```"

    # Find the JSON object using regex
    match = re.search(pattern, text, re.DOTALL)

    if match:
        json_str = match.group(1)
        
        # Parse the JSON object
        try:
            json_object = json.loads(json_str)
            return json_object
        except json.JSONDecodeError as e:
            print("Error decoding JSON:", e)
            return None
    else:
        print("JSON object not found in the text.")
        return None

# Print the parsed model_setup_anndata_args object.
print(parse_json_from_text(card.text))


{'labels_key': 'cell_ontology_class', 'unlabeled_category': 'unknown', 'layer': None, 'batch_key': 'donor_assay', 'size_factor_key': None, 'categorical_covariate_keys': None, 'continuous_covariate_keys': None}
